test_pyexpat.py revision 0ddbf4795f40d2d3386dc56ffa264b50a015f6c9
1# XXX TypeErrors on calling handlers, or on bad return values from a
2# handler, are obscure and unhelpful.
3
from io import BytesIO
import os
import sysconfig
import unittest
import traceback

from xml.parsers import expat
from xml.parsers.expat import errors

from test.support import sortdict, run_unittest
13
14
class SetAttributeTest(unittest.TestCase):
    """Round-trip values through the parser's attribute-reporting flags."""

    def setUp(self):
        # Each entry is [value assigned, value expected when read back].
        self.set_get_pairs = [
            [0, 0],
            [1, 1],
            [2, 1],
            [0, 0],
            ]
        self.parser = expat.ParserCreate(namespace_separator='!')

    def _assert_round_trips(self, attribute):
        # Assign every test value to the named parser attribute in turn and
        # compare what is read back with the expected value.
        parser = self.parser
        for assigned, read_back in self.set_get_pairs:
            setattr(parser, attribute, assigned)
            self.assertEqual(getattr(parser, attribute), read_back)

    def test_ordered_attributes(self):
        self._assert_round_trips('ordered_attributes')

    def test_specified_attributes(self):
        self._assert_round_trips('specified_attributes')
34
35
# Sample document shared by the parsing tests below (ParseTest and
# InterningTest.test_issue9402).  It is a bytes literal declared as
# iso-8859-1 and contains, in order: an XML declaration, a processing
# instruction, a comment, a DOCTYPE with an internal subset (element,
# attribute list, notation and entity declarations plus a parameter entity
# reference), a namespaced element, a CDATA section, two entity references,
# and one raw non-ASCII byte (\xb5).
data = b'''\
<?xml version="1.0" encoding="iso-8859-1" standalone="no"?>
<?xml-stylesheet href="stylesheet.css"?>
<!-- comment data -->
<!DOCTYPE quotations SYSTEM "quotations.dtd" [
<!ELEMENT root ANY>
<!ATTLIST root attr1 CDATA #REQUIRED attr2 CDATA #IMPLIED>
<!NOTATION notation SYSTEM "notation.jpeg">
<!ENTITY acirc "&#226;">
<!ENTITY external_entity SYSTEM "entity.file">
<!ENTITY unparsed_entity SYSTEM "entity.file" NDATA notation>
%unparsed_entity;
]>

<root attr1="value1" attr2="value2&#8000;">
<myns:subelement xmlns:myns="http://www.python.org/namespace">
     Contents of subelements
</myns:subelement>
<sub2><![CDATA[contents of CDATA section]]></sub2>
&external_entity;
&skipped_entity;
\xb5
</root>
'''
60
61
# Record every parser callback and compare the log with the expected events
class ParseTest(unittest.TestCase):
    """Parse the sample document and verify the sequence of callbacks."""

    class Outputter:
        """Handler collection that logs each callback it receives.

        Every handler appends one entry (a string or a tuple) to ``out``;
        the default handlers are installed but deliberately log nothing.
        """

        def __init__(self):
            self.out = []

        def StartElementHandler(self, name, attrs):
            self.out.append('Start element: ' + repr(name) + ' ' +
                            sortdict(attrs))

        def EndElementHandler(self, name):
            self.out.append('End element: ' + repr(name))

        def CharacterDataHandler(self, data):
            # Whitespace-only chunks are dropped so that the log does not
            # depend on the layout of the sample document.
            data = data.strip()
            if data:
                self.out.append('Character data: ' + repr(data))

        def ProcessingInstructionHandler(self, target, data):
            self.out.append('PI: ' + repr(target) + ' ' + repr(data))

        def StartNamespaceDeclHandler(self, prefix, uri):
            self.out.append('NS decl: ' + repr(prefix) + ' ' + repr(uri))

        def EndNamespaceDeclHandler(self, prefix):
            self.out.append('End of NS decl: ' + repr(prefix))

        def StartCdataSectionHandler(self):
            self.out.append('Start of CDATA section')

        def EndCdataSectionHandler(self):
            self.out.append('End of CDATA section')

        def CommentHandler(self, text):
            self.out.append('Comment: ' + repr(text))

        def NotationDeclHandler(self, *args):
            # The unpacking doubles as a check on the number of arguments.
            name, base, sysid, pubid = args
            self.out.append('Notation declared: %s' %(args,))

        def UnparsedEntityDeclHandler(self, *args):
            # The unpacking doubles as a check on the number of arguments.
            entityName, base, systemId, publicId, notationName = args
            self.out.append('Unparsed entity decl: %s' %(args,))

        def NotStandaloneHandler(self):
            self.out.append('Not standalone')
            return 1

        def ExternalEntityRefHandler(self, *args):
            # The unpacking doubles as a check on the number of arguments;
            # the context (first argument) is left out of the log.
            context, base, sysId, pubId = args
            self.out.append('External entity ref: %s' %(args[1:],))
            return 1

        def StartDoctypeDeclHandler(self, *args):
            self.out.append(('Start doctype', args))
            return 1

        def EndDoctypeDeclHandler(self):
            self.out.append("End doctype")
            return 1

        def EntityDeclHandler(self, *args):
            self.out.append(('Entity declaration', args))
            return 1

        def XmlDeclHandler(self, *args):
            self.out.append(('XML declaration', args))
            return 1

        def ElementDeclHandler(self, *args):
            self.out.append(('Element declaration', args))
            return 1

        def AttlistDeclHandler(self, *args):
            self.out.append(('Attribute list declaration', args))
            return 1

        def SkippedEntityHandler(self, *args):
            self.out.append(("Skipped entity", args))
            return 1

        def DefaultHandler(self, userData):
            pass

        def DefaultHandlerExpand(self, userData):
            pass

    # Names of the parser attributes wired up by _hookup_callbacks(); each
    # one must also be a method of the handler object.
    handler_names = [
        'StartElementHandler', 'EndElementHandler', 'CharacterDataHandler',
        'ProcessingInstructionHandler', 'UnparsedEntityDeclHandler',
        'NotationDeclHandler', 'StartNamespaceDeclHandler',
        'EndNamespaceDeclHandler', 'CommentHandler',
        'StartCdataSectionHandler', 'EndCdataSectionHandler', 'DefaultHandler',
        'DefaultHandlerExpand', 'NotStandaloneHandler',
        'ExternalEntityRefHandler', 'StartDoctypeDeclHandler',
        'EndDoctypeDeclHandler', 'EntityDeclHandler', 'XmlDeclHandler',
        'ElementDeclHandler', 'AttlistDeclHandler', 'SkippedEntityHandler',
        ]

    def _hookup_callbacks(self, parser, handler):
        """
        Set each of the callbacks defined on handler and named in
        self.handler_names on the given parser.
        """
        for name in self.handler_names:
            setattr(parser, name, getattr(handler, name))

    def _verify_parse_output(self, operations):
        """
        Check that operations, the log collected by an Outputter, matches
        the events expected for the sample document, entry by entry and in
        total length.
        """
        expected_operations = [
            ('XML declaration', ('1.0', 'iso-8859-1', 0)),
            'PI: \'xml-stylesheet\' \'href="stylesheet.css"\'',
            "Comment: ' comment data '",
            "Not standalone",
            ("Start doctype", ('quotations', 'quotations.dtd', None, 1)),
            ('Element declaration', ('root', (2, 0, None, ()))),
            ('Attribute list declaration', ('root', 'attr1', 'CDATA', None,
                1)),
            ('Attribute list declaration', ('root', 'attr2', 'CDATA', None,
                0)),
            "Notation declared: ('notation', None, 'notation.jpeg', None)",
            ('Entity declaration', ('acirc', 0, '\xe2', None, None, None, None)),
            ('Entity declaration', ('external_entity', 0, None, None,
                'entity.file', None, None)),
            "Unparsed entity decl: ('unparsed_entity', None, 'entity.file', None, 'notation')",
            "Not standalone",
            "End doctype",
            "Start element: 'root' {'attr1': 'value1', 'attr2': 'value2\u1f40'}",
            "NS decl: 'myns' 'http://www.python.org/namespace'",
            "Start element: 'http://www.python.org/namespace!subelement' {}",
            "Character data: 'Contents of subelements'",
            "End element: 'http://www.python.org/namespace!subelement'",
            "End of NS decl: 'myns'",
            "Start element: 'sub2' {}",
            'Start of CDATA section',
            "Character data: 'contents of CDATA section'",
            'End of CDATA section',
            "End element: 'sub2'",
            "External entity ref: (None, 'entity.file', None)",
            ('Skipped entity', ('skipped_entity', 0)),
            "Character data: '\xb5'",
            "End element: 'root'",
        ]
        # Compare entry by entry first so that a failure points at the
        # first event that differs.
        for operation, expected_operation in zip(operations, expected_operations):
            self.assertEqual(operation, expected_operation)
        # zip() stops at the shorter sequence, so the lengths must be
        # compared as well; otherwise a truncated (or even empty) log
        # would be accepted.
        self.assertEqual(len(operations), len(expected_operations))

    def test_parse_bytes(self):
        out = self.Outputter()
        parser = expat.ParserCreate(namespace_separator='!')
        self._hookup_callbacks(parser, out)

        parser.Parse(data, 1)

        operations = out.out
        self._verify_parse_output(operations)
        # Issue #6697.
        self.assertRaises(AttributeError, getattr, parser, '\uD800')

    def test_parse_str(self):
        out = self.Outputter()
        parser = expat.ParserCreate(namespace_separator='!')
        self._hookup_callbacks(parser, out)

        parser.Parse(data.decode('iso-8859-1'), 1)

        operations = out.out
        self._verify_parse_output(operations)

    def test_parse_file(self):
        # Try parsing a file
        out = self.Outputter()
        parser = expat.ParserCreate(namespace_separator='!')
        self._hookup_callbacks(parser, out)
        stream = BytesIO(data)

        parser.ParseFile(stream)

        operations = out.out
        self._verify_parse_output(operations)

    def test_parse_again(self):
        parser = expat.ParserCreate()
        stream = BytesIO(data)
        parser.ParseFile(stream)
        # Issue 6676: ensure a meaningful exception is raised when attempting
        # to parse more than one XML document per xmlparser instance,
        # a limitation of the Expat library.
        with self.assertRaises(expat.error) as cm:
            parser.ParseFile(stream)
        self.assertEqual(expat.ErrorString(cm.exception.code),
                          expat.errors.XML_ERROR_FINISHED)
252
253class NamespaceSeparatorTest(unittest.TestCase):
254    def test_legal(self):
255        # Tests that make sure we get errors when the namespace_separator value
256        # is illegal, and that we don't for good values:
257        expat.ParserCreate()
258        expat.ParserCreate(namespace_separator=None)
259        expat.ParserCreate(namespace_separator=' ')
260
261    def test_illegal(self):
262        try:
263            expat.ParserCreate(namespace_separator=42)
264            self.fail()
265        except TypeError as e:
266            self.assertEqual(str(e),
267                'ParserCreate() argument 2 must be str or None, not int')
268
269        try:
270            expat.ParserCreate(namespace_separator='too long')
271            self.fail()
272        except ValueError as e:
273            self.assertEqual(str(e),
274                'namespace_separator must be at most one character, omitted, or None')
275
276    def test_zero_length(self):
277        # ParserCreate() needs to accept a namespace_separator of zero length
278        # to satisfy the requirements of RDF applications that are required
279        # to simply glue together the namespace URI and the localname.  Though
280        # considered a wart of the RDF specifications, it needs to be supported.
281        #
282        # See XML-SIG mailing list thread starting with
283        # http://mail.python.org/pipermail/xml-sig/2001-April/005202.html
284        #
285        expat.ParserCreate(namespace_separator='') # too short
286
287
class InterningTest(unittest.TestCase):
    """Tests around the parser's interning of reported names."""

    def test(self):
        # Test the interning machinery.
        p = expat.ParserCreate()
        L = []
        def collector(name, *args):
            L.append(name)
        p.StartElementHandler = collector
        p.EndElementHandler = collector
        p.Parse(b"<e> <e/> <e></e> </e>", 1)
        # Check the count before indexing, so that a parser reporting no
        # events fails with an assertion rather than an IndexError.
        self.assertEqual(len(L), 6)
        tag = L[0]
        for entry in L:
            # L should have the same string repeated over and over.
            self.assertIs(tag, entry)

    def test_issue9402(self):
        # create an ExternalEntityParserCreate with buffer text
        class ExternalOutputter:
            def __init__(self, parser):
                self.parser = parser
                self.parser_result = None

            def ExternalEntityRefHandler(self, context, base, sysId, pubId):
                # Create a sub-parser from the buffering parent and record
                # what parsing an empty document with it returns.
                external_parser = self.parser.ExternalEntityParserCreate("")
                self.parser_result = external_parser.Parse(b"", 1)
                return 1

        parser = expat.ParserCreate(namespace_separator='!')
        parser.buffer_text = 1
        out = ExternalOutputter(parser)
        parser.ExternalEntityRefHandler = out.ExternalEntityRefHandler
        parser.Parse(data, 1)
        self.assertEqual(out.parser_result, 1)
322
323
class BufferTextTest(unittest.TestCase):
    """Tests for the buffer_text attribute of parser objects.

    The handlers below append what they receive to self.stuff; each test
    then compares that list with the expected sequence.
    """

    def setUp(self):
        self.stuff = []
        self.parser = expat.ParserCreate()
        self.parser.buffer_text = 1
        self.parser.CharacterDataHandler = self.CharacterDataHandler

    def check(self, expected, label):
        """Assert that self.stuff equals expected, labelling any failure."""
        # map() returns a lazy iterator on Python 3; materialise it so the
        # failure message shows the expected items instead of
        # "<map object at 0x...>".
        self.assertEqual(self.stuff, expected,
                "%s\nstuff    = %r\nexpected = %r"
                % (label, self.stuff, list(map(str, expected))))

    def CharacterDataHandler(self, text):
        self.stuff.append(text)

    def StartElementHandler(self, name, attrs):
        self.stuff.append("<%s>" % name)
        # A "buffer-text" attribute on the element switches buffering on or
        # off from inside the callback.
        bt = attrs.get("buffer-text")
        if bt == "yes":
            self.parser.buffer_text = 1
        elif bt == "no":
            self.parser.buffer_text = 0

    def EndElementHandler(self, name):
        self.stuff.append("</%s>" % name)

    def CommentHandler(self, data):
        self.stuff.append("<!--%s-->" % data)

    def setHandlers(self, handlers=()):
        """Install the named methods of this test case on self.parser."""
        # The default is an immutable empty tuple rather than a shared
        # mutable list.
        for name in handlers:
            setattr(self.parser, name, getattr(self, name))

    def test_default_to_disabled(self):
        parser = expat.ParserCreate()
        self.assertFalse(parser.buffer_text)

    def test_buffering_enabled(self):
        # Make sure buffering is turned on
        self.assertTrue(self.parser.buffer_text)
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", 1)
        self.assertEqual(self.stuff, ['123'],
                         "buffered text not properly collapsed")

    def test1(self):
        # XXX This test exposes more detail of Expat's text chunking than we
        # XXX like, but it tests what we need to concisely.
        self.setHandlers(["StartElementHandler"])
        self.parser.Parse(b"<a>1<b buffer-text='no'/>2\n3<c buffer-text='yes'/>4\n5</a>", 1)
        self.assertEqual(self.stuff,
                         ["<a>", "1", "<b>", "2", "\n", "3", "<c>", "4\n5"],
                         "buffering control not reacting as expected")

    def test2(self):
        self.parser.Parse(b"<a>1<b/>&lt;2&gt;<c/>&#32;\n&#x20;3</a>", 1)
        self.assertEqual(self.stuff, ["1<2> \n 3"],
                         "buffered text not properly collapsed")

    def test3(self):
        self.setHandlers(["StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", 1)
        self.assertEqual(self.stuff, ["<a>", "1", "<b>", "2", "<c>", "3"],
                         "buffered text not properly split")

    def test4(self):
        self.setHandlers(["StartElementHandler", "EndElementHandler"])
        self.parser.CharacterDataHandler = None
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", 1)
        self.assertEqual(self.stuff,
                         ["<a>", "<b>", "</b>", "<c>", "</c>", "</a>"])

    def test5(self):
        self.setHandlers(["StartElementHandler", "EndElementHandler"])
        self.parser.Parse(b"<a>1<b></b>2<c/>3</a>", 1)
        self.assertEqual(self.stuff,
            ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3", "</a>"])

    def test6(self):
        self.setHandlers(["CommentHandler", "EndElementHandler",
                    "StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c></c>345</a> ", 1)
        self.assertEqual(self.stuff,
            ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "345", "</a>"],
            "buffered text not properly split")

    def test7(self):
        self.setHandlers(["CommentHandler", "EndElementHandler",
                    "StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c></c>3<!--abc-->4<!--def-->5</a> ", 1)
        self.assertEqual(self.stuff,
                         ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3",
                          "<!--abc-->", "4", "<!--def-->", "5", "</a>"],
                         "buffered text not properly split")
417
418
419# Test handling of exception from callback:
420class HandlerExceptionTest(unittest.TestCase):
421    def StartElementHandler(self, name, attrs):
422        raise RuntimeError(name)
423
424    def check_traceback_entry(self, entry, filename, funcname):
425        self.assertEqual(os.path.basename(entry[0]), filename)
426        self.assertEqual(entry[2], funcname)
427
428    def test_exception(self):
429        parser = expat.ParserCreate()
430        parser.StartElementHandler = self.StartElementHandler
431        try:
432            parser.Parse(b"<a><b><c/></b></a>", 1)
433            self.fail()
434        except RuntimeError as e:
435            self.assertEqual(e.args[0], 'a',
436                             "Expected RuntimeError for element 'a', but" + \
437                             " found %r" % e.args[0])
438            # Check that the traceback contains the relevant line in pyexpat.c
439            entries = traceback.extract_tb(e.__traceback__)
440            self.assertEqual(len(entries), 3)
441            self.check_traceback_entry(entries[0],
442                                       "test_pyexpat.py", "test_exception")
443            self.check_traceback_entry(entries[1],
444                                       "pyexpat.c", "StartElement")
445            self.check_traceback_entry(entries[2],
446                                       "test_pyexpat.py", "StartElementHandler")
447            self.assertIn('call_with_frame("StartElement"', entries[1][3])
448
449
450# Test Current* members:
class PositionTest(unittest.TestCase):
    """Check the Current* position attributes as seen from callbacks."""

    def StartElementHandler(self, name, attrs):
        self.check_pos('s')

    def EndElementHandler(self, name):
        self.check_pos('e')

    def check_pos(self, event):
        """Compare the parser's current position with the next expected one.

        event is 's' for a start tag or 'e' for an end tag; each expected
        entry is (event, byte index, line number, column number).
        """
        pos = (event,
               self.parser.CurrentByteIndex,
               self.parser.CurrentLineNumber,
               self.parser.CurrentColumnNumber)
        self.assertTrue(self.upto < len(self.expected_list),
                        'too many parser events')
        expected = self.expected_list[self.upto]
        # The message names the expected value first and the observed
        # value second.
        self.assertEqual(pos, expected,
                'Expected position %s, got position %s' %(expected, pos))
        self.upto += 1

    def test(self):
        self.parser = expat.ParserCreate()
        self.parser.StartElementHandler = self.StartElementHandler
        self.parser.EndElementHandler = self.EndElementHandler
        self.upto = 0
        self.expected_list = [('s', 0, 1, 0), ('s', 5, 2, 1), ('s', 11, 3, 2),
                              ('e', 15, 3, 6), ('e', 17, 4, 1), ('e', 22, 5, 0)]

        xml = b'<a>\n <b>\n  <c/>\n </b>\n</a>'
        self.parser.Parse(xml, 1)
        # Every expected event must have been seen; without this check a
        # parser that reports fewer events would still pass.
        self.assertEqual(self.upto, len(self.expected_list),
                         'too few parser events')
480
481
class sf1296433Test(unittest.TestCase):
    """Regression test for http://python.org/sf/1296433."""

    def test_parse_only_xml_data(self):
        # http://python.org/sf/1296433
        #
        xml = "<?xml version='1.0' encoding='iso8859'?><s>%s</s>" % ('a' * 1025)
        # this one doesn't crash
        #xml = "<?xml version='1.0'?><s>%s</s>" % ('a' * 10000)

        class SpecificException(Exception):
            pass

        def handler(text):
            raise SpecificException

        parser = expat.ParserCreate()
        parser.CharacterDataHandler = handler

        # Require the handler's own exception: asserting the Exception base
        # class would also accept an unrelated failure such as an
        # expat.error raised by the parser itself.
        self.assertRaises(SpecificException, parser.Parse,
                          xml.encode('iso8859'))
500
class ChardataBufferTest(unittest.TestCase):
    """
    Checks for the buffer_size attribute used when character data is buffered
    """

    def counting_handler(self, text):
        # Character data callback that only counts how often it is called.
        self.n += 1

    def small_buffer_test(self, buffer_len):
        # Parse a document holding buffer_len characters of text with
        # buffer_size set to 1024 and return the number of handler calls.
        payload = b'a' * buffer_len
        xml = b"<?xml version='1.0' encoding='iso8859'?><s>" + payload + b'</s>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_size = 1024
        parser.buffer_text = 1

        self.n = 0
        parser.Parse(xml)
        return self.n

    def test_1025_bytes(self):
        calls = self.small_buffer_test(1025)
        self.assertEqual(calls, 2)

    def test_1000_bytes(self):
        calls = self.small_buffer_test(1000)
        self.assertEqual(calls, 1)

    def test_wrong_size(self):
        parser = expat.ParserCreate()
        parser.buffer_text = 1

        # Neither a negative size nor zero may be assigned.
        for invalid_size in (-1, 0):
            with self.assertRaises(ValueError):
                parser.buffer_size = invalid_size

    def test_unchanged_size(self):
        prologue = b"<?xml version='1.0' encoding='iso8859'?><s>"
        first_part = prologue + b'a' * 512
        second_part = b'a' * 512 + b'</s>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_size = 512
        parser.buffer_text = 1

        # Feed 512 bytes of character data: the handler should be called
        # once.
        self.n = 0
        parser.Parse(first_part)
        self.assertEqual(self.n, 1)

        # Assigning the size it already has must not trigger the handler.
        current_size = parser.buffer_size
        parser.buffer_size = current_size
        self.assertEqual(self.n, 1)

        # The rest of the document accounts for one more call.
        parser.Parse(second_part)
        self.assertEqual(self.n, 2)

    def test_disabling_buffer(self):
        opening = b"<?xml version='1.0' encoding='iso8859'?><a>" + b'a' * 512
        middle = b'b' * 1024
        closing = b'c' * 1024 + b'</a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 1024
        self.assertEqual(parser.buffer_size, 1024)

        # First chunk, parsed with buffering on.
        self.n = 0
        parser.Parse(opening, 0)
        self.assertEqual(parser.buffer_size, 1024)
        self.assertEqual(self.n, 1)

        # Switch buffering off, then feed the middle chunk ten times.
        parser.buffer_text = 0
        self.assertFalse(parser.buffer_text)
        self.assertEqual(parser.buffer_size, 1024)
        for _ in range(10):
            parser.Parse(middle, 0)
        self.assertEqual(self.n, 11)

        # Switch buffering back on for the final chunk.
        parser.buffer_text = 1
        self.assertTrue(parser.buffer_text)
        self.assertEqual(parser.buffer_size, 1024)
        parser.Parse(closing, 1)
        self.assertEqual(self.n, 12)

    def test_change_size_1(self):
        first_part = (b"<?xml version='1.0' encoding='iso8859'?><a><s>"
                      + b'a' * 1024)
        second_part = b'aaa</s><s>' + b'a' * 1025 + b'</s></a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 1024
        self.assertEqual(parser.buffer_size, 1024)

        self.n = 0
        parser.Parse(first_part, 0)
        # Double the buffer size between the two chunks.
        parser.buffer_size = parser.buffer_size * 2
        self.assertEqual(parser.buffer_size, 2048)
        parser.Parse(second_part, 1)
        self.assertEqual(self.n, 2)

    def test_change_size_2(self):
        first_part = (b"<?xml version='1.0' encoding='iso8859'?><a>a<s>"
                      + b'a' * 1023)
        second_part = b'aaa</s><s>' + b'a' * 1025 + b'</s></a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 2048
        self.assertEqual(parser.buffer_size, 2048)

        self.n = 0
        parser.Parse(first_part, 0)
        # Halve the buffer size between the two chunks.
        parser.buffer_size //= 2
        self.assertEqual(parser.buffer_size, 1024)
        parser.Parse(second_part, 1)
        self.assertEqual(self.n, 4)
619
620class MalformedInputTest(unittest.TestCase):
621    def test1(self):
622        xml = b"\0\r\n"
623        parser = expat.ParserCreate()
624        try:
625            parser.Parse(xml, True)
626            self.fail()
627        except expat.ExpatError as e:
628            self.assertEqual(str(e), 'unclosed token: line 2, column 0')
629
630    def test2(self):
631        # \xc2\x85 is UTF-8 encoded U+0085 (NEXT LINE)
632        xml = b"<?xml version\xc2\x85='1.0'?>\r\n"
633        parser = expat.ParserCreate()
634        try:
635            parser.Parse(xml, True)
636            self.fail()
637        except expat.ExpatError as e:
638            self.assertEqual(str(e), 'XML declaration not well-formed: line 1, column 14')
639
class ErrorMessageTest(unittest.TestCase):
    """Consistency of the error tables and of the code on ExpatError."""

    def test_codes(self):
        # verify mapping of errors.codes and errors.messages: looking a
        # message up in codes and the result in messages must give the
        # message back.
        message = errors.XML_ERROR_SYNTAX
        code = errors.codes[message]
        self.assertEqual(message, errors.messages[code])

    def test_expaterror(self):
        # A lone '<' must be reported with the "unclosed token" code.
        parser = expat.ParserCreate()
        with self.assertRaises(expat.ExpatError) as caught:
            parser.Parse(b'<', True)
        expected_code = errors.codes[errors.XML_ERROR_UNCLOSED_TOKEN]
        self.assertEqual(caught.exception.code, expected_code)
655
656
class ForeignDTDTests(unittest.TestCase):
    """
    Tests for the UseForeignDTD method of expat parser objects.
    """

    def _recorded_entity_refs(self, document, *use_foreign_dtd_args):
        # Parse document with a fresh parser on which UseForeignDTD() has
        # been called with the given arguments, and return the
        # (public_id, system_id) pairs passed to ExternalEntityRefHandler.
        recorded = []

        def resolve_entity(context, base, system_id, public_id):
            recorded.append((public_id, system_id))
            return 1

        parser = expat.ParserCreate()
        parser.UseForeignDTD(*use_foreign_dtd_args)
        parser.SetParamEntityParsing(expat.XML_PARAM_ENTITY_PARSING_ALWAYS)
        parser.ExternalEntityRefHandler = resolve_entity
        parser.Parse(document)
        return recorded

    def test_use_foreign_dtd(self):
        """
        If UseForeignDTD is passed True and a document without an external
        entity reference is parsed, ExternalEntityRefHandler is first called
        with None for the public and system ids.
        """
        document = b"<?xml version='1.0'?><element/>"
        self.assertEqual(self._recorded_entity_refs(document, True),
                         [(None, None)])

        # test UseForeignDTD() is equal to UseForeignDTD(True)
        self.assertEqual(self._recorded_entity_refs(document),
                         [(None, None)])

    def test_ignore_use_foreign_dtd(self):
        """
        If UseForeignDTD is passed True and a document with an external
        entity reference is parsed, ExternalEntityRefHandler is called with
        the public and system ids from the document.
        """
        document = (
            b"<?xml version='1.0'?><!DOCTYPE foo PUBLIC 'bar' 'baz'><element/>")
        self.assertEqual(self._recorded_entity_refs(document, True),
                         [("bar", "baz")])
707
708
def test_main():
    # Hand every test case class of this module to run_unittest(), in the
    # order in which the classes are defined.
    test_classes = (
        SetAttributeTest,
        ParseTest,
        NamespaceSeparatorTest,
        InterningTest,
        BufferTextTest,
        HandlerExceptionTest,
        PositionTest,
        sf1296433Test,
        ChardataBufferTest,
        MalformedInputTest,
        ErrorMessageTest,
        ForeignDTDTests,
    )
    run_unittest(*test_classes)
722
# Run the whole suite when this file is executed as a script.
if __name__ == "__main__":
    test_main()
725