1# regression test for SAX 2.0
2# $Id$
3
4from xml.sax import make_parser, ContentHandler, \
5                    SAXException, SAXReaderNotAvailable, SAXParseException
6import unittest
# Probe for a usable SAX parser at import time; if make_parser() cannot
# create one, skip this entire test module.
try:
    make_parser()
except SAXReaderNotAvailable:
    # don't try to test this module if we cannot create a parser
    raise unittest.SkipTest("no XML parsers available")
12from xml.sax.saxutils import XMLGenerator, escape, unescape, quoteattr, \
13                             XMLFilterBase, prepare_input_source
14from xml.sax.expatreader import create_parser
15from xml.sax.handler import feature_namespaces
16from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
17from io import BytesIO, StringIO
18import codecs
19import gc
20import os.path
21import shutil
22from test import support
23from test.support import findfile, run_unittest, TESTFN
24
# Reference input document and the expected XMLGenerator output for it
# (the expatreader tests compare generated bytes against the .out file).
TEST_XMLFILE = findfile("test.xml", subdir="xmltestdata")
TEST_XMLFILE_OUT = findfile("test.xml.out", subdir="xmltestdata")
# Skip the whole module when either data-file path cannot be encoded as UTF-8.
try:
    TEST_XMLFILE.encode("utf-8")
    TEST_XMLFILE_OUT.encode("utf-8")
except UnicodeEncodeError:
    raise unittest.SkipTest("filename is not encodable to utf8")

# Whether support.TESTFN_UNICODE can be used as a file name here.  On
# platforms without native Unicode file names it must be encodable in the
# file-system encoding.
supports_nonascii_filenames = True
if not os.path.supports_unicode_filenames:
    try:
        support.TESTFN_UNICODE.encode(support.TESTFN_ENCODING)
    except (UnicodeError, TypeError):
        # Either the file system encoding is None, or the file name
        # cannot be encoded in the file system encoding.
        supports_nonascii_filenames = False
# Decorator that skips tests needing a non-ASCII file name.
requires_nonascii_filenames = unittest.skipUnless(
        supports_nonascii_filenames,
        'Requires non-ascii filenames support')

# Namespace URI used by the namespace-aware tests.
ns_uri = "http://www.python.org/xml-ns/saxtest/"
46
class XmlTestBase(unittest.TestCase):
    """Shared assertions about SAX attribute objects."""

    def verify_empty_attrs(self, attrs):
        """Assert that the non-namespace *attrs* holds no attributes.

        Every lookup of the probe name ``"attr"`` must miss: the accessor
        methods and ``[]`` raise ``KeyError``, ``get`` returns its default,
        and all the size/listing views are empty.
        """
        self.assertRaises(KeyError, attrs.getValue, "attr")
        self.assertRaises(KeyError, attrs.getValueByQName, "attr")
        self.assertRaises(KeyError, attrs.getNameByQName, "attr")
        self.assertRaises(KeyError, attrs.getQNameByName, "attr")
        self.assertRaises(KeyError, attrs.__getitem__, "attr")
        self.assertEqual(attrs.getLength(), 0)
        self.assertEqual(attrs.getNames(), [])
        self.assertEqual(attrs.getQNames(), [])
        self.assertEqual(len(attrs), 0)
        self.assertNotIn("attr", attrs)
        self.assertEqual(list(attrs.keys()), [])
        # Probe the same name as every other assertion in this helper.
        self.assertEqual(attrs.get("attr"), None)
        self.assertEqual(attrs.get("attr", 25), 25)
        self.assertEqual(list(attrs.items()), [])
        self.assertEqual(list(attrs.values()), [])

    def verify_empty_nsattrs(self, attrs):
        """Assert that the namespace-aware *attrs* holds no attributes.

        Names are ``(ns_uri, "attr")`` tuples and the qualified name is
        ``"ns:attr"``; every lookup must miss.
        """
        self.assertRaises(KeyError, attrs.getValue, (ns_uri, "attr"))
        self.assertRaises(KeyError, attrs.getValueByQName, "ns:attr")
        self.assertRaises(KeyError, attrs.getNameByQName, "ns:attr")
        self.assertRaises(KeyError, attrs.getQNameByName, (ns_uri, "attr"))
        self.assertRaises(KeyError, attrs.__getitem__, (ns_uri, "attr"))
        self.assertEqual(attrs.getLength(), 0)
        self.assertEqual(attrs.getNames(), [])
        self.assertEqual(attrs.getQNames(), [])
        self.assertEqual(len(attrs), 0)
        self.assertNotIn((ns_uri, "attr"), attrs)
        self.assertEqual(list(attrs.keys()), [])
        self.assertEqual(attrs.get((ns_uri, "attr")), None)
        self.assertEqual(attrs.get((ns_uri, "attr"), 25), 25)
        self.assertEqual(list(attrs.items()), [])
        self.assertEqual(list(attrs.values()), [])

    def verify_attrs_wattr(self, attrs):
        """Assert that *attrs* holds exactly ``attr="val"`` (no namespace)."""
        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), ["attr"])
        self.assertEqual(attrs.getQNames(), ["attr"])
        self.assertEqual(len(attrs), 1)
        self.assertIn("attr", attrs)
        self.assertEqual(list(attrs.keys()), ["attr"])
        self.assertEqual(attrs.get("attr"), "val")
        self.assertEqual(attrs.get("attr", 25), "val")
        self.assertEqual(list(attrs.items()), [("attr", "val")])
        self.assertEqual(list(attrs.values()), ["val"])
        self.assertEqual(attrs.getValue("attr"), "val")
        self.assertEqual(attrs.getValueByQName("attr"), "val")
        self.assertEqual(attrs.getNameByQName("attr"), "attr")
        self.assertEqual(attrs["attr"], "val")
        self.assertEqual(attrs.getQNameByName("attr"), "attr")
98
99
def xml_str(doc, encoding=None):
    """Return *doc*, prefixed by an XML declaration naming *encoding*.

    When *encoding* is None no declaration is added and *doc* is returned
    as is.
    """
    if encoding is not None:
        declaration = '<?xml version="1.0" encoding="%s"?>' % (encoding,)
        return '%s\n%s' % (declaration, doc)
    return doc
104
def xml_bytes(doc, encoding, decl_encoding=...):
    """Return *doc* encoded with *encoding*, optionally with a declaration.

    *decl_encoding* is the name written into the XML declaration. It
    defaults to *encoding*, and None omits the declaration.  Characters
    that *encoding* cannot represent become numeric character references.
    """
    declared = encoding if decl_encoding is ... else decl_encoding
    text = xml_str(doc, declared)
    return text.encode(encoding, 'xmlcharrefreplace')
109
def make_xml_file(doc, encoding, decl_encoding=...):
    """Write *doc* to TESTFN in *encoding*, optionally with a declaration.

    *decl_encoding* is the name written into the XML declaration. It
    defaults to *encoding*, and None omits the declaration.  Characters
    that *encoding* cannot represent become numeric character references.
    """
    declared = encoding if decl_encoding is ... else decl_encoding
    with open(TESTFN, 'w', encoding=encoding,
              errors='xmlcharrefreplace') as out:
        out.write(xml_str(doc, declared))
115
116
class ParseTest(unittest.TestCase):
    """Round-trip ``data`` through xml.sax.parse() and parseString().

    Every check feeds a document (str, bytes, a file name, an open file or
    an InputSource) to the parser with an XMLGenerator handler configured
    for UTF-8. It then compares the regenerated text with ``data``
    preceded by a UTF-8 XML declaration.
    """
    # Non-ASCII characters in both an attribute value and element content;
    # '\U0001017b' lies outside the Basic Multilingual Plane.
    data = '<money value="$\xa3\u20ac\U0001017b">$\xa3\u20ac\U0001017b</money>'

    def tearDown(self):
        """Remove the scratch file written by make_xml_file()."""
        support.unlink(TESTFN)

    def check_parse(self, f):
        """Parse *f* with xml.sax.parse() and check the regenerated text."""
        from xml.sax import parse
        result = StringIO()
        parse(f, XMLGenerator(result, 'utf-8'))
        self.assertEqual(result.getvalue(), xml_str(self.data, 'utf-8'))

    def test_parse_text(self):
        # Already-decoded (str) input, from a StringIO and from a file
        # opened in text mode, with and without an XML declaration.
        encodings = ('us-ascii', 'iso-8859-1', 'utf-8',
                     'utf-16', 'utf-16le', 'utf-16be')
        for encoding in encodings:
            self.check_parse(StringIO(xml_str(self.data, encoding)))
            make_xml_file(self.data, encoding)
            with open(TESTFN, 'r', encoding=encoding) as f:
                self.check_parse(f)
            # NOTE: this undeclared StringIO case does not depend on
            # *encoding*; it is simply repeated on every iteration.
            self.check_parse(StringIO(self.data))
            make_xml_file(self.data, encoding, None)
            with open(TESTFN, 'r', encoding=encoding) as f:
                self.check_parse(f)

    def test_parse_bytes(self):
        # Byte input, given as a BytesIO, as a file name and as a file
        # opened in binary mode.
        # UTF-8 is default encoding, US-ASCII is compatible with UTF-8,
        # UTF-16 is autodetected
        encodings = ('us-ascii', 'utf-8', 'utf-16', 'utf-16le', 'utf-16be')
        for encoding in encodings:
            self.check_parse(BytesIO(xml_bytes(self.data, encoding)))
            make_xml_file(self.data, encoding)
            self.check_parse(TESTFN)
            with open(TESTFN, 'rb') as f:
                self.check_parse(f)
            self.check_parse(BytesIO(xml_bytes(self.data, encoding, None)))
            make_xml_file(self.data, encoding, None)
            self.check_parse(TESTFN)
            with open(TESTFN, 'rb') as f:
                self.check_parse(f)
        # accept UTF-8 with BOM
        self.check_parse(BytesIO(xml_bytes(self.data, 'utf-8-sig', 'utf-8')))
        make_xml_file(self.data, 'utf-8-sig', 'utf-8')
        self.check_parse(TESTFN)
        with open(TESTFN, 'rb') as f:
            self.check_parse(f)
        self.check_parse(BytesIO(xml_bytes(self.data, 'utf-8-sig', None)))
        make_xml_file(self.data, 'utf-8-sig', None)
        self.check_parse(TESTFN)
        with open(TESTFN, 'rb') as f:
            self.check_parse(f)
        # accept data with declared encoding
        self.check_parse(BytesIO(xml_bytes(self.data, 'iso-8859-1')))
        make_xml_file(self.data, 'iso-8859-1')
        self.check_parse(TESTFN)
        with open(TESTFN, 'rb') as f:
            self.check_parse(f)
        # fail on UTF-8-incompatible data without a declared encoding
        with self.assertRaises(SAXException):
            self.check_parse(BytesIO(xml_bytes(self.data, 'iso-8859-1', None)))
        make_xml_file(self.data, 'iso-8859-1', None)
        # Presumably check_warnings requires an 'unclosed file'
        # ResourceWarning, caused by the file the failed parse left open;
        # gc.collect() makes sure that warning fires inside this block.
        with support.check_warnings(('unclosed file', ResourceWarning)):
            # XXX Failed parser leaks an opened file.
            with self.assertRaises(SAXException):
                self.check_parse(TESTFN)
            # Collect leaked file.
            gc.collect()
        with open(TESTFN, 'rb') as f:
            with self.assertRaises(SAXException):
                self.check_parse(f)

    def test_parse_InputSource(self):
        # accept data without a declared encoding when the encoding is
        # explicitly specified on the InputSource
        make_xml_file(self.data, 'iso-8859-1', None)
        with open(TESTFN, 'rb') as f:
            input = InputSource()
            input.setByteStream(f)
            input.setEncoding('iso-8859-1')
            self.check_parse(input)

    def check_parseString(self, s):
        """Parse *s* with xml.sax.parseString() and check the regenerated text."""
        from xml.sax import parseString
        result = StringIO()
        parseString(s, XMLGenerator(result, 'utf-8'))
        self.assertEqual(result.getvalue(), xml_str(self.data, 'utf-8'))

    def test_parseString_text(self):
        # str input with a declaration for each encoding, then undeclared.
        encodings = ('us-ascii', 'iso-8859-1', 'utf-8',
                     'utf-16', 'utf-16le', 'utf-16be')
        for encoding in encodings:
            self.check_parseString(xml_str(self.data, encoding))
        self.check_parseString(self.data)

    def test_parseString_bytes(self):
        # UTF-8 is default encoding, US-ASCII is compatible with UTF-8,
        # UTF-16 is autodetected
        encodings = ('us-ascii', 'utf-8', 'utf-16', 'utf-16le', 'utf-16be')
        for encoding in encodings:
            self.check_parseString(xml_bytes(self.data, encoding))
            self.check_parseString(xml_bytes(self.data, encoding, None))
        # accept UTF-8 with BOM
        self.check_parseString(xml_bytes(self.data, 'utf-8-sig', 'utf-8'))
        self.check_parseString(xml_bytes(self.data, 'utf-8-sig', None))
        # accept data with declared encoding
        self.check_parseString(xml_bytes(self.data, 'iso-8859-1'))
        # fail on UTF-8-incompatible data without a declared encoding
        with self.assertRaises(SAXException):
            self.check_parseString(xml_bytes(self.data, 'iso-8859-1', None))
225
class MakeParserTest(unittest.TestCase):
    """Regression test: make_parser() must work when called repeatedly."""

    def test_make_parser2(self):
        # Creating parsers several times in a row has failed in the past,
        # so repeat "import make_parser, create a parser" six times.
        for _round in range(6):
            from xml.sax import make_parser
            parser = make_parser()
243
244
245# ===========================================================================
246#
247#   saxutils tests
248#
249# ===========================================================================
250
class SaxutilsTest(unittest.TestCase):
    """Tests for escape(), unescape() and quoteattr() plus make_parser()."""

    # ===== escape
    def test_escape_basic(self):
        text = "Donald Duck & Co"
        self.assertEqual(escape(text), "Donald Duck &amp; Co")

    def test_escape_all(self):
        # '<', '&' and '>' are all replaced.
        escaped = escape("<Donald Duck & Co>")
        self.assertEqual(escaped, "&lt;Donald Duck &amp; Co&gt;")

    def test_escape_extra(self):
        # Keys of the extra mapping are replaced by their values.
        extra = {"å": "&aring;"}
        self.assertEqual(escape("Hei på deg", extra), "Hei p&aring; deg")

    # ===== unescape
    def test_unescape_basic(self):
        text = "Donald Duck &amp; Co"
        self.assertEqual(unescape(text), "Donald Duck & Co")

    def test_unescape_all(self):
        unescaped = unescape("&lt;Donald Duck &amp; Co&gt;")
        self.assertEqual(unescaped, "<Donald Duck & Co>")

    def test_unescape_extra(self):
        # As with escape(), keys of the extra mapping become its values.
        extra = {"å": "&aring;"}
        self.assertEqual(unescape("Hei på deg", extra), "Hei p&aring; deg")

    def test_unescape_amp_extra(self):
        # Text produced by decoding "&amp;" must not be matched again
        # against the extra entities.
        extra = {"&foo;": "splat"}
        self.assertEqual(unescape("&amp;foo;", extra), "&foo;")

    # ===== quoteattr
    def test_quoteattr_basic(self):
        quoted = quoteattr("Donald Duck & Co")
        self.assertEqual(quoted, '"Donald Duck &amp; Co"')

    def test_single_quoteattr(self):
        # Only double quotes inside: wrap the value in single quotes.
        quoted = quoteattr('Includes "double" quotes')
        self.assertEqual(quoted, '\'Includes "double" quotes\'')

    def test_double_quoteattr(self):
        # Only single quotes inside: wrap the value in double quotes.
        quoted = quoteattr("Includes 'single' quotes")
        self.assertEqual(quoted, "\"Includes 'single' quotes\"")

    def test_single_double_quoteattr(self):
        # Both kinds inside: double-quote the value and escape the
        # embedded double quotes.
        quoted = quoteattr("Includes 'single' and \"double\" quotes")
        self.assertEqual(quoted,
                         "\"Includes 'single' and &quot;double&quot; quotes\"")

    # ===== make_parser
    def test_make_parser(self):
        # An unknown driver name must not prevent parser creation: it
        # should fall back to the expatreader.
        parser = make_parser(['xml.parsers.no_such_parser'])
301
302
303class PrepareInputSourceTest(unittest.TestCase):
304
305    def setUp(self):
306        self.file = support.TESTFN
307        with open(self.file, "w") as tmp:
308            tmp.write("This was read from a file.")
309
310    def tearDown(self):
311        support.unlink(self.file)
312
313    def make_byte_stream(self):
314        return BytesIO(b"This is a byte stream.")
315
316    def make_character_stream(self):
317        return StringIO("This is a character stream.")
318
319    def checkContent(self, stream, content):
320        self.assertIsNotNone(stream)
321        self.assertEqual(stream.read(), content)
322        stream.close()
323
324
325    def test_character_stream(self):
326        # If the source is an InputSource with a character stream, use it.
327        src = InputSource(self.file)
328        src.setCharacterStream(self.make_character_stream())
329        prep = prepare_input_source(src)
330        self.assertIsNone(prep.getByteStream())
331        self.checkContent(prep.getCharacterStream(),
332                          "This is a character stream.")
333
334    def test_byte_stream(self):
335        # If the source is an InputSource that does not have a character
336        # stream but does have a byte stream, use the byte stream.
337        src = InputSource(self.file)
338        src.setByteStream(self.make_byte_stream())
339        prep = prepare_input_source(src)
340        self.assertIsNone(prep.getCharacterStream())
341        self.checkContent(prep.getByteStream(),
342                          b"This is a byte stream.")
343
344    def test_system_id(self):
345        # If the source is an InputSource that has neither a character
346        # stream nor a byte stream, open the system ID.
347        src = InputSource(self.file)
348        prep = prepare_input_source(src)
349        self.assertIsNone(prep.getCharacterStream())
350        self.checkContent(prep.getByteStream(),
351                          b"This was read from a file.")
352
353    def test_string(self):
354        # If the source is a string, use it as a system ID and open it.
355        prep = prepare_input_source(self.file)
356        self.assertIsNone(prep.getCharacterStream())
357        self.checkContent(prep.getByteStream(),
358                          b"This was read from a file.")
359
360    def test_binary_file(self):
361        # If the source is a binary file-like object, use it as a byte
362        # stream.
363        prep = prepare_input_source(self.make_byte_stream())
364        self.assertIsNone(prep.getCharacterStream())
365        self.checkContent(prep.getByteStream(),
366                          b"This is a byte stream.")
367
368    def test_text_file(self):
369        # If the source is a text file-like object, use it as a character
370        # stream.
371        prep = prepare_input_source(self.make_character_stream())
372        self.assertIsNone(prep.getByteStream())
373        self.checkContent(prep.getCharacterStream(),
374                          "This is a character stream.")
375
376
377# ===== XMLGenerator
378
class XmlgenTest:
    """XMLGenerator tests shared by several output-stream flavours.

    This is a mixin. Concrete subclasses also inherit unittest.TestCase and
    supply two things: ``ioclass``, a factory for the output object, which
    must provide ``getvalue()``; and ``xml(doc, encoding=...)``, which
    returns the expected output for *doc*, XML declaration included.
    """

    def test_xmlgen_basic(self):
        out = self.ioclass()
        writer = XMLGenerator(out)
        writer.startDocument()
        writer.startElement("doc", {})
        writer.endElement("doc")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml("<doc></doc>"))

    def test_xmlgen_basic_empty(self):
        # With short_empty_elements an element without content is
        # written in self-closing form.
        out = self.ioclass()
        writer = XMLGenerator(out, short_empty_elements=True)
        writer.startDocument()
        writer.startElement("doc", {})
        writer.endElement("doc")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml("<doc/>"))

    def test_xmlgen_content(self):
        out = self.ioclass()
        writer = XMLGenerator(out)

        writer.startDocument()
        writer.startElement("doc", {})
        writer.characters("huhei")
        writer.endElement("doc")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml("<doc>huhei</doc>"))

    def test_xmlgen_content_empty(self):
        # Character data rules out the self-closing form.
        out = self.ioclass()
        writer = XMLGenerator(out, short_empty_elements=True)

        writer.startDocument()
        writer.startElement("doc", {})
        writer.characters("huhei")
        writer.endElement("doc")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml("<doc>huhei</doc>"))

    def test_xmlgen_pi(self):
        out = self.ioclass()
        writer = XMLGenerator(out)

        writer.startDocument()
        writer.processingInstruction("test", "data")
        writer.startElement("doc", {})
        writer.endElement("doc")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml("<?test data?><doc></doc>"))

    def test_xmlgen_content_escape(self):
        out = self.ioclass()
        writer = XMLGenerator(out)

        writer.startDocument()
        writer.startElement("doc", {})
        writer.characters("<huhei&")
        writer.endElement("doc")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml("<doc>&lt;huhei&amp;</doc>"))

    def test_xmlgen_attr_escape(self):
        out = self.ioclass()
        writer = XMLGenerator(out)

        writer.startDocument()
        writer.startElement("doc", {"a": '"'})
        # Three empty <e> children with increasingly awkward values.
        for value in ("'", "'\"", "\n\r\t"):
            writer.startElement("e", {"a": value})
            writer.endElement("e")
        writer.endElement("doc")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml(
            "<doc a='\"'><e a=\"'\"></e>"
            "<e a=\"'&quot;\"></e>"
            "<e a=\"&#10;&#13;&#9;\"></e></doc>"))

    def test_xmlgen_encoding(self):
        for enc in ('iso-8859-15', 'utf-8', 'utf-8-sig',
                    'utf-16', 'utf-16be', 'utf-16le',
                    'utf-32', 'utf-32be', 'utf-32le'):
            out = self.ioclass()
            writer = XMLGenerator(out, encoding=enc)

            writer.startDocument()
            writer.startElement("doc", {"a": '\u20ac'})
            writer.characters("\u20ac")
            writer.endElement("doc")
            writer.endDocument()

            got = out.getvalue()
            self.assertEqual(
                got, self.xml('<doc a="\u20ac">\u20ac</doc>', encoding=enc))

    def test_xmlgen_unencodable(self):
        # The euro sign cannot be written in ASCII and must come out as a
        # character reference.
        out = self.ioclass()
        writer = XMLGenerator(out, encoding='ascii')

        writer.startDocument()
        writer.startElement("doc", {"a": '\u20ac'})
        writer.characters("\u20ac")
        writer.endElement("doc")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(
            got, self.xml('<doc a="&#8364;">&#8364;</doc>', encoding='ascii'))

    def test_xmlgen_ignorable(self):
        out = self.ioclass()
        writer = XMLGenerator(out)

        writer.startDocument()
        writer.startElement("doc", {})
        writer.ignorableWhitespace(" ")
        writer.endElement("doc")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml("<doc> </doc>"))

    def test_xmlgen_ignorable_empty(self):
        out = self.ioclass()
        writer = XMLGenerator(out, short_empty_elements=True)

        writer.startDocument()
        writer.startElement("doc", {})
        writer.ignorableWhitespace(" ")
        writer.endElement("doc")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml("<doc> </doc>"))

    def test_xmlgen_encoding_bytes(self):
        # Character data may also be handed over as already-encoded bytes.
        for enc in ('iso-8859-15', 'utf-8', 'utf-8-sig',
                    'utf-16', 'utf-16be', 'utf-16le',
                    'utf-32', 'utf-32be', 'utf-32le'):
            out = self.ioclass()
            writer = XMLGenerator(out, encoding=enc)

            writer.startDocument()
            writer.startElement("doc", {"a": '\u20ac'})
            writer.characters("\u20ac".encode(enc))
            writer.ignorableWhitespace(" ".encode(enc))
            writer.endElement("doc")
            writer.endDocument()

            got = out.getvalue()
            self.assertEqual(
                got, self.xml('<doc a="\u20ac">\u20ac </doc>', encoding=enc))

    def test_xmlgen_ns(self):
        out = self.ioclass()
        writer = XMLGenerator(out)

        writer.startDocument()
        writer.startPrefixMapping("ns1", ns_uri)
        writer.startElementNS((ns_uri, "doc"), "ns1:doc", {})
        # add an unqualified name
        writer.startElementNS((None, "udoc"), None, {})
        writer.endElementNS((None, "udoc"), None)
        writer.endElementNS((ns_uri, "doc"), "ns1:doc")
        writer.endPrefixMapping("ns1")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml(
            '<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' % ns_uri))

    def test_xmlgen_ns_empty(self):
        out = self.ioclass()
        writer = XMLGenerator(out, short_empty_elements=True)

        writer.startDocument()
        writer.startPrefixMapping("ns1", ns_uri)
        writer.startElementNS((ns_uri, "doc"), "ns1:doc", {})
        # add an unqualified name
        writer.startElementNS((None, "udoc"), None, {})
        writer.endElementNS((None, "udoc"), None)
        writer.endElementNS((ns_uri, "doc"), "ns1:doc")
        writer.endPrefixMapping("ns1")
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml(
            '<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' % ns_uri))

    def test_1463026_1(self):
        # Namespace-less attribute on a namespace-less element.
        out = self.ioclass()
        writer = XMLGenerator(out)

        writer.startDocument()
        writer.startElementNS((None, 'a'), 'a', {(None, 'b'): 'c'})
        writer.endElementNS((None, 'a'), 'a')
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml('<a b="c"></a>'))

    def test_1463026_1_empty(self):
        out = self.ioclass()
        writer = XMLGenerator(out, short_empty_elements=True)

        writer.startDocument()
        writer.startElementNS((None, 'a'), 'a', {(None, 'b'): 'c'})
        writer.endElementNS((None, 'a'), 'a')
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml('<a b="c"/>'))

    def test_1463026_2(self):
        # A default (prefix None) namespace mapping becomes xmlns="...".
        out = self.ioclass()
        writer = XMLGenerator(out)

        writer.startDocument()
        writer.startPrefixMapping(None, 'qux')
        writer.startElementNS(('qux', 'a'), 'a', {})
        writer.endElementNS(('qux', 'a'), 'a')
        writer.endPrefixMapping(None)
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml('<a xmlns="qux"></a>'))

    def test_1463026_2_empty(self):
        out = self.ioclass()
        writer = XMLGenerator(out, short_empty_elements=True)

        writer.startDocument()
        writer.startPrefixMapping(None, 'qux')
        writer.startElementNS(('qux', 'a'), 'a', {})
        writer.endElementNS(('qux', 'a'), 'a')
        writer.endPrefixMapping(None)
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml('<a xmlns="qux"/>'))

    def test_1463026_3(self):
        # A prefixed mapping plus a namespace-less attribute.
        out = self.ioclass()
        writer = XMLGenerator(out)

        writer.startDocument()
        writer.startPrefixMapping('my', 'qux')
        writer.startElementNS(('qux', 'a'), 'a', {(None, 'b'): 'c'})
        writer.endElementNS(('qux', 'a'), 'a')
        writer.endPrefixMapping('my')
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml('<my:a xmlns:my="qux" b="c"></my:a>'))

    def test_1463026_3_empty(self):
        out = self.ioclass()
        writer = XMLGenerator(out, short_empty_elements=True)

        writer.startDocument()
        writer.startPrefixMapping('my', 'qux')
        writer.startElementNS(('qux', 'a'), 'a', {(None, 'b'): 'c'})
        writer.endElementNS(('qux', 'a'), 'a')
        writer.endPrefixMapping('my')
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml('<my:a xmlns:my="qux" b="c"/>'))

    def test_5027_1(self):
        # The reserved "xml" prefix (as in xml:lang) is bound by definition
        # to http://www.w3.org/XML/1998/namespace.  XMLGenerator used to
        # raise KeyError because that namespace was missing from a
        # dictionary.  This variant triggers the bug by parsing a document.
        source = StringIO(
            '<?xml version="1.0"?>'
            '<a:g1 xmlns:a="http://example.com/ns">'
             '<a:g2 xml:lang="en">Hello</a:g2>'
            '</a:g1>')

        parser = make_parser()
        parser.setFeature(feature_namespaces, True)
        out = self.ioclass()
        writer = XMLGenerator(out)
        parser.setContentHandler(writer)
        parser.parse(source)

        got = out.getvalue()
        self.assertEqual(got, self.xml(
            '<a:g1 xmlns:a="http://example.com/ns">'
             '<a:g2 xml:lang="en">Hello</a:g2>'
            '</a:g1>'))

    def test_5027_2(self):
        # Same xml:lang regression as test_5027_1, but driving the
        # XMLGenerator directly instead of going through a parser.
        out = self.ioclass()
        writer = XMLGenerator(out)

        writer.startDocument()
        writer.startPrefixMapping('a', 'http://example.com/ns')
        writer.startElementNS(('http://example.com/ns', 'g1'), 'g1', {})
        lang_attr = {('http://www.w3.org/XML/1998/namespace', 'lang'): 'en'}
        writer.startElementNS(('http://example.com/ns', 'g2'), 'g2', lang_attr)
        writer.characters('Hello')
        writer.endElementNS(('http://example.com/ns', 'g2'), 'g2')
        writer.endElementNS(('http://example.com/ns', 'g1'), 'g1')
        writer.endPrefixMapping('a')
        writer.endDocument()

        got = out.getvalue()
        self.assertEqual(got, self.xml(
            '<a:g1 xmlns:a="http://example.com/ns">'
             '<a:g2 xml:lang="en">Hello</a:g2>'
            '</a:g1>'))

    def test_no_close_file(self):
        # Discarding the generator must not close the caller's stream.
        out = self.ioclass()

        def emit(stream):
            writer = XMLGenerator(stream)
            writer.startDocument()
            writer.startElement("doc", {})

        emit(out)
        self.assertFalse(out.closed)

    def test_xmlgen_fragment(self):
        out = self.ioclass()
        writer = XMLGenerator(out)

        # Deliberately skip startDocument()/endDocument(): the output is a
        # bare fragment without an XML declaration.
        writer.startElement("foo", {"a": "1.0"})
        writer.characters("Hello")
        writer.endElement("foo")
        writer.startElement("bar", {"b": "2.0"})
        writer.endElement("bar")

        got = out.getvalue()
        # self.xml() always prepends the declaration, so strip it again.
        expected = self.xml('<foo a="1.0">Hello</foo><bar b="2.0"></bar>')
        self.assertEqual(got, expected[len(self.xml('')):])
731
class StringXmlgenTest(XmlgenTest, unittest.TestCase):
    """Run the XMLGenerator tests against an in-memory text stream."""

    ioclass = StringIO

    def xml(self, doc, encoding='iso-8859-1'):
        # Expected output stays str because the stream is a text stream.
        declaration = '<?xml version="1.0" encoding="%s"?>' % (encoding,)
        return '%s\n%s' % (declaration, doc)

    # Replace the inherited test with None so it is not collected here;
    # presumably text written to a StringIO is never encoded, so the
    # character-reference fallback cannot be observed.
    test_xmlgen_unencodable = None
739
class BytesXmlgenTest(XmlgenTest, unittest.TestCase):
    """Run the XMLGenerator tests against an in-memory bytes stream."""

    ioclass = BytesIO

    def xml(self, doc, encoding='iso-8859-1'):
        # Expected bytes in *encoding*; characters it cannot represent
        # become numeric character references.
        text = '<?xml version="1.0" encoding="%s"?>\n%s' % (encoding, doc)
        return text.encode(encoding, 'xmlcharrefreplace')
746
class WriterXmlgenTest(BytesXmlgenTest):
    """Run the XMLGenerator tests against a bare object with write()."""

    class ioclass(list):
        """Writable sink that records every chunk written to it."""

        write = list.append
        closed = False

        def seekable(self):
            return True

        def tell(self):
            # Zero before the first write, non-zero once data has arrived.
            return len(self)

        def getvalue(self):
            return bytes().join(self)
761
class StreamWriterXmlgenTest(XmlgenTest, unittest.TestCase):
    """Run the XMLGenerator tests against an ASCII codecs StreamWriter."""

    def ioclass(self):
        # ASCII StreamWriter over a BytesIO. It borrows the buffer's
        # getvalue() so the shared tests can read the result back.
        buffer = BytesIO()
        stream = codecs.getwriter('ascii')(buffer, 'xmlcharrefreplace')
        stream.getvalue = buffer.getvalue
        return stream

    def xml(self, doc, encoding='iso-8859-1'):
        # The declaration names *encoding*, but the bytes are always ASCII;
        # other characters become numeric character references.
        text = '<?xml version="1.0" encoding="%s"?>\n%s' % (encoding, doc)
        return text.encode('ascii', 'xmlcharrefreplace')
772
class StreamReaderWriterXmlgenTest(XmlgenTest, unittest.TestCase):
    """Run the XMLGenerator tests against an ASCII file from codecs.open()."""

    fname = support.TESTFN + '-codecs'

    def ioclass(self):
        def read_back():
            # Windows will not let us reopen the file without first
            # closing it.
            stream.close()
            with open(stream.name, 'rb') as saved:
                return saved.read()

        def discard():
            stream.close()
            support.unlink(self.fname)

        stream = codecs.open(self.fname, 'w', encoding='ascii',
                             errors='xmlcharrefreplace', buffering=0)
        self.addCleanup(discard)
        stream.getvalue = read_back
        return stream

    def xml(self, doc, encoding='iso-8859-1'):
        # The declaration names *encoding*, but the bytes are always ASCII;
        # other characters become numeric character references.
        text = '<?xml version="1.0" encoding="%s"?>\n%s' % (encoding, doc)
        return text.encode('ascii', 'xmlcharrefreplace')
794
# XML declaration that XMLGenerator writes when constructed without an
# explicit encoding (iso-8859-1), as expected by XMLFilterBaseTest.
start = b'<?xml version="1.0" encoding="iso-8859-1"?>\n'
796
797
class XMLFilterBaseTest(unittest.TestCase):
    """Tests for xml.sax.saxutils.XMLFilterBase."""

    def test_filter_basic(self):
        # Events sent to the filter must reach its content handler, an
        # XMLGenerator writing into a BytesIO.
        sink = BytesIO()
        handler = XMLGenerator(sink)
        xml_filter = XMLFilterBase()
        xml_filter.setContentHandler(handler)

        xml_filter.startDocument()
        xml_filter.startElement("doc", {})
        xml_filter.characters("content")
        xml_filter.ignorableWhitespace(" ")
        xml_filter.endElement("doc")
        xml_filter.endDocument()

        got = sink.getvalue()
        self.assertEqual(got, start + b"<doc>content </doc>")
813
814# ===========================================================================
815#
816#   expatreader tests
817#
818# ===========================================================================
819
# Reference output: the bytes an XMLGenerator is expected to produce when
# the expat reader parses TEST_XMLFILE. Loaded once at import time.
with open(TEST_XMLFILE_OUT, 'rb') as f:
    xml_test_out = f.read()
822
class ExpatReaderTest(XmlTestBase):
    """Tests for the expat-based SAX driver (xml.sax.expatreader)."""

    # ===== XMLReader support

    def test_expat_binary_file(self):
        # Parsing a binary-mode file must reproduce the reference output.
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        with open(TEST_XMLFILE, 'rb') as f:
            parser.parse(f)

        self.assertEqual(result.getvalue(), xml_test_out)

    def test_expat_text_file(self):
        # Same as above, but from a text-mode file decoded as iso-8859-1.
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        with open(TEST_XMLFILE, 'rt', encoding='iso-8859-1') as f:
            parser.parse(f)

        self.assertEqual(result.getvalue(), xml_test_out)

    @requires_nonascii_filenames
    def test_expat_binary_file_nonascii(self):
        # Parse a copy of the test file stored under a non-ASCII name.
        fname = support.TESTFN_UNICODE
        shutil.copyfile(TEST_XMLFILE, fname)
        self.addCleanup(support.unlink, fname)

        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        # Use a with-block so the handle is closed deterministically before
        # the unlink cleanup runs (an open file cannot be deleted on
        # Windows) and no ResourceWarning is emitted.
        with open(fname, 'rb') as f:
            parser.parse(f)

        self.assertEqual(result.getvalue(), xml_test_out)

    def test_expat_binary_file_bytes_name(self):
        # The file is opened via a bytes path instead of a str path.
        fname = os.fsencode(TEST_XMLFILE)
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        with open(fname, 'rb') as f:
            parser.parse(f)

        self.assertEqual(result.getvalue(), xml_test_out)

    def test_expat_binary_file_int_name(self):
        # The file object handed to the parser is opened from a raw file
        # descriptor, so its ``name`` is an int rather than a path.
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        with open(TEST_XMLFILE, 'rb') as f:
            with open(f.fileno(), 'rb', closefd=False) as f2:
                parser.parse(f2)

        self.assertEqual(result.getvalue(), xml_test_out)

    # ===== DTDHandler support

    class TestDTDHandler:
        """DTD handler that records notation and unparsed-entity events."""

        def __init__(self):
            self._notations = []
            self._entities  = []

        def notationDecl(self, name, publicId, systemId):
            self._notations.append((name, publicId, systemId))

        def unparsedEntityDecl(self, name, publicId, systemId, ndata):
            self._entities.append((name, publicId, systemId, ndata))

    def test_expat_dtdhandler(self):
        # Feed an internal DTD subset and check both declarations are
        # delivered to the DTD handler with the expected fields.
        parser = create_parser()
        handler = self.TestDTDHandler()
        parser.setDTDHandler(handler)

        parser.feed('<!DOCTYPE doc [\n')
        parser.feed('  <!ENTITY img SYSTEM "expat.gif" NDATA GIF>\n')
        parser.feed('  <!NOTATION GIF PUBLIC "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN">\n')
        parser.feed(']>\n')
        parser.feed('<doc></doc>')
        parser.close()

        self.assertEqual(handler._notations,
            [("GIF", "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN", None)])
        self.assertEqual(handler._entities, [("img", None, "expat.gif", "GIF")])

    # ===== EntityResolver support

    class TestEntityResolver:
        """Resolve every external entity to the document ``<entity/>``."""

        def resolveEntity(self, publicId, systemId):
            inpsrc = InputSource()
            inpsrc.setByteStream(BytesIO(b"<entity/>"))
            return inpsrc

    def test_expat_entityresolver(self):
        # A reference to an external entity is expanded using whatever the
        # registered resolver returns.
        parser = create_parser()
        parser.setEntityResolver(self.TestEntityResolver())
        result = BytesIO()
        parser.setContentHandler(XMLGenerator(result))

        parser.feed('<!DOCTYPE doc [\n')
        parser.feed('  <!ENTITY test SYSTEM "whatever">\n')
        parser.feed(']>\n')
        parser.feed('<doc>&test;</doc>')
        parser.close()

        self.assertEqual(result.getvalue(), start +
                         b"<doc><entity></entity></doc>")

    # ===== Attributes support

    class AttrGatherer(ContentHandler):
        """Content handler keeping the attributes of the last start tag."""

        def startElement(self, name, attrs):
            self._attrs = attrs

        def startElementNS(self, name, qname, attrs):
            self._attrs = attrs

    def test_expat_attrs_empty(self):
        parser = create_parser()
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc/>")
        parser.close()

        self.verify_empty_attrs(gather._attrs)

    def test_expat_attrs_wattr(self):
        parser = create_parser()
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc attr='val'/>")
        parser.close()

        self.verify_attrs_wattr(gather._attrs)

    def test_expat_nsattrs_empty(self):
        # create_parser(1) builds a namespace-aware parser.
        parser = create_parser(1)
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc/>")
        parser.close()

        self.verify_empty_nsattrs(gather._attrs)

    def test_expat_nsattrs_wattr(self):
        parser = create_parser(1)
        gather = self.AttrGatherer()
        parser.setContentHandler(gather)

        parser.feed("<doc xmlns:ns='%s' ns:attr='val'/>" % ns_uri)
        parser.close()

        attrs = gather._attrs

        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), [(ns_uri, "attr")])
        # Either no qnames or exactly the prefixed qname is acceptable;
        # assertIn reports the actual value on failure.
        self.assertIn(attrs.getQNames(), ([], ["ns:attr"]))
        self.assertEqual(len(attrs), 1)
        self.assertIn((ns_uri, "attr"), attrs)
        self.assertEqual(attrs.get((ns_uri, "attr")), "val")
        self.assertEqual(attrs.get((ns_uri, "attr"), 25), "val")
        self.assertEqual(list(attrs.items()), [((ns_uri, "attr"), "val")])
        self.assertEqual(list(attrs.values()), ["val"])
        self.assertEqual(attrs.getValue((ns_uri, "attr")), "val")
        self.assertEqual(attrs[(ns_uri, "attr")], "val")

    # ===== InputSource support

    def test_expat_inpsource_filename(self):
        # parse() accepts a plain filename.
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        parser.parse(TEST_XMLFILE)

        self.assertEqual(result.getvalue(), xml_test_out)

    def test_expat_inpsource_sysid(self):
        # parse() accepts an InputSource carrying only a system id.
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        parser.parse(InputSource(TEST_XMLFILE))

        self.assertEqual(result.getvalue(), xml_test_out)

    @requires_nonascii_filenames
    def test_expat_inpsource_sysid_nonascii(self):
        fname = support.TESTFN_UNICODE
        shutil.copyfile(TEST_XMLFILE, fname)
        self.addCleanup(support.unlink, fname)

        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        parser.parse(InputSource(fname))

        self.assertEqual(result.getvalue(), xml_test_out)

    def test_expat_inpsource_byte_stream(self):
        # InputSource wrapping an already-open binary stream.
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        inpsrc = InputSource()
        with open(TEST_XMLFILE, 'rb') as f:
            inpsrc.setByteStream(f)
            parser.parse(inpsrc)

        self.assertEqual(result.getvalue(), xml_test_out)

    def test_expat_inpsource_character_stream(self):
        # InputSource wrapping an already-open text stream.
        parser = create_parser()
        result = BytesIO()
        xmlgen = XMLGenerator(result)

        parser.setContentHandler(xmlgen)
        inpsrc = InputSource()
        with open(TEST_XMLFILE, 'rt', encoding='iso-8859-1') as f:
            inpsrc.setCharacterStream(f)
            parser.parse(inpsrc)

        self.assertEqual(result.getvalue(), xml_test_out)

    # ===== IncrementalParser support

    def test_expat_incremental(self):
        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser = create_parser()
        parser.setContentHandler(xmlgen)

        parser.feed("<doc>")
        parser.feed("</doc>")
        parser.close()

        self.assertEqual(result.getvalue(), start + b"<doc></doc>")

    def test_expat_incremental_reset(self):
        # Feed a partial document, then reset() and parse a complete one
        # with a fresh handler; only the second document should be seen.
        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser = create_parser()
        parser.setContentHandler(xmlgen)

        parser.feed("<doc>")
        parser.feed("text")

        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser.setContentHandler(xmlgen)
        parser.reset()

        parser.feed("<doc>")
        parser.feed("text")
        parser.feed("</doc>")
        parser.close()

        self.assertEqual(result.getvalue(), start + b"<doc>text</doc>")

    # ===== Locator support

    def test_expat_locator_noinfo(self):
        # Fed data has no system/public id.
        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser = create_parser()
        parser.setContentHandler(xmlgen)

        parser.feed("<doc>")
        parser.feed("</doc>")
        parser.close()

        self.assertEqual(parser.getSystemId(), None)
        self.assertEqual(parser.getPublicId(), None)
        self.assertEqual(parser.getLineNumber(), 1)

    def test_expat_locator_withinfo(self):
        # Parsing by filename records that filename as the system id.
        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser = create_parser()
        parser.setContentHandler(xmlgen)
        parser.parse(TEST_XMLFILE)

        self.assertEqual(parser.getSystemId(), TEST_XMLFILE)
        self.assertEqual(parser.getPublicId(), None)

    @requires_nonascii_filenames
    def test_expat_locator_withinfo_nonascii(self):
        fname = support.TESTFN_UNICODE
        shutil.copyfile(TEST_XMLFILE, fname)
        self.addCleanup(support.unlink, fname)

        result = BytesIO()
        xmlgen = XMLGenerator(result)
        parser = create_parser()
        parser.setContentHandler(xmlgen)
        parser.parse(fname)

        self.assertEqual(parser.getSystemId(), fname)
        self.assertEqual(parser.getPublicId(), None)
1143
1144
1145# ===========================================================================
1146#
1147#   error reporting
1148#
1149# ===========================================================================
1150
class ErrorReportingTest(unittest.TestCase):
    """Checks on how parse errors are surfaced through SAX exceptions."""

    def test_expat_inpsource_location(self):
        # The exception raised for an ill-formed InputSource must carry
        # the system id that was set on that source.
        sysid = "a file name"
        src = InputSource()
        src.setByteStream(BytesIO(b"<foo bar foobar>"))   #ill-formed
        src.setSystemId(sysid)

        parser = create_parser()
        parser.setContentHandler(ContentHandler()) # do nothing
        with self.assertRaises(SAXException) as ctx:
            parser.parse(src)
        self.assertEqual(ctx.exception.getSystemId(), sysid)

    def test_expat_incomplete(self):
        # A truncated document fails; afterwards the locator reports line
        # 1, column 5 (the length of the input "<foo>").
        parser = create_parser()
        parser.setContentHandler(ContentHandler()) # do nothing
        with self.assertRaises(SAXParseException):
            parser.parse(StringIO("<foo>"))
        self.assertEqual(parser.getColumnNumber(), 5)
        self.assertEqual(parser.getLineNumber(), 1)

    def test_sax_parse_exception_str(self):
        # str() of a SAXParseException must not fall apart when the
        # locator returns None instead of an integer line and/or column
        # number. Cover normal values, each None alone, and both None.
        positions = ((1, 1), (None, 1), (1, None), (None, None))
        for line, column in positions:
            exc = SAXParseException("message", None,
                                    self.DummyLocator(line, column))
            str(exc)

    class DummyLocator:
        """Minimal locator returning fixed ids and the given position."""

        def __init__(self, lineno, colno):
            self._lineno = lineno
            self._colno = colno

        def getPublicId(self):
            return "pubid"

        def getSystemId(self):
            return "sysid"

        def getLineNumber(self):
            return self._lineno

        def getColumnNumber(self):
            return self._colno
1206
1207# ===========================================================================
1208#
1209#   xmlreader tests
1210#
1211# ===========================================================================
1212
class XmlReaderTest(XmlTestBase):
    """Direct tests of the AttributesImpl / AttributesNSImpl containers."""

    # ===== AttributesImpl
    def test_attrs_empty(self):
        no_attrs = AttributesImpl({})
        self.verify_empty_attrs(no_attrs)

    def test_attrs_wattr(self):
        one_attr = AttributesImpl({"attr" : "val"})
        self.verify_attrs_wattr(one_attr)

    def test_nsattrs_empty(self):
        no_attrs = AttributesNSImpl({}, {})
        self.verify_empty_nsattrs(no_attrs)

    def test_nsattrs_wattr(self):
        # One namespaced attribute (ns_uri, "attr") with qname "ns:attr".
        key = (ns_uri, "attr")
        attrs = AttributesNSImpl({key : "val"}, {key : "ns:attr"})

        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), [key])
        self.assertEqual(attrs.getQNames(), ["ns:attr"])
        self.assertEqual(len(attrs), 1)
        self.assertIn(key, attrs)
        self.assertEqual(list(attrs.keys()), [key])
        self.assertEqual(attrs.get(key), "val")
        self.assertEqual(attrs.get(key, 25), "val")
        self.assertEqual(list(attrs.items()), [(key, "val")])
        self.assertEqual(list(attrs.values()), ["val"])
        self.assertEqual(attrs.getValue(key), "val")
        self.assertEqual(attrs.getValueByQName("ns:attr"), "val")
        self.assertEqual(attrs.getNameByQName("ns:attr"), key)
        self.assertEqual(attrs[key], "val")
        self.assertEqual(attrs.getQNameByName(key), "ns:attr")
1244
1245
def test_main():
    """Run this module's TestCase classes through run_unittest()."""
    # XMLFilterBaseTest must be listed explicitly; otherwise it never runs
    # when the module is driven via test_main().
    run_unittest(MakeParserTest,
                 ParseTest,
                 SaxutilsTest,
                 PrepareInputSourceTest,
                 StringXmlgenTest,
                 BytesXmlgenTest,
                 WriterXmlgenTest,
                 StreamWriterXmlgenTest,
                 StreamReaderWriterXmlgenTest,
                 XMLFilterBaseTest,
                 ExpatReaderTest,
                 ErrorReportingTest,
                 XmlReaderTest)

if __name__ == "__main__":
    test_main()
1262