1"""\
2A library of useful helper classes to the SAX classes, for the
3convenience of application and driver writers.
4"""
5
6import os, urllib.parse, urllib.request
7import io
8import codecs
9from . import handler
10from . import xmlreader
11
def __dict_replace(s, d):
    """Return s after applying every substitution listed in d.

    Each key of d found in s is replaced by the corresponding value.
    The substitutions are applied one after another, in the order the
    mapping yields its items, so a later substitution sees the result
    of the earlier ones.
    """
    result = s
    for substitution in d.items():
        # substitution is an (old, new) pair
        result = result.replace(*substitution)
    return result
17
def escape(data, entities={}):
    """Return data with &, < and > replaced by XML entity references.

    Further substitutions can be requested through the optional
    entities dictionary: every key found in the (already escaped) data
    is replaced by its value.  Keys and values must all be strings.
    """
    # The ampersand has to be handled first; otherwise the "&" that
    # starts "&gt;" and "&lt;" would be escaped a second time.
    for char, reference in (("&", "&amp;"), (">", "&gt;"), ("<", "&lt;")):
        data = data.replace(char, reference)
    return __dict_replace(data, entities) if entities else data
33
def unescape(data, entities={}):
    """Return data with &lt;, &gt; and &amp; turned back into <, > and &.

    Further substitutions can be requested through the optional
    entities dictionary: every key found in the data is replaced by
    its value.  Keys and values must all be strings.
    """
    text = data.replace("&lt;", "<").replace("&gt;", ">")
    if entities:
        text = __dict_replace(text, entities)
    # The ampersand has to be handled last, so that an escaped
    # reference such as "&amp;lt;" becomes "&lt;" rather than "<".
    text = text.replace("&amp;", "&")
    return text
47
def quoteattr(data, entities={}):
    """Return data escaped and wrapped in quotes, ready to be used as
    an attribute value.

    &, < and > are escaped, and newline, carriage return and tab are
    written as character references.  Double quotes are used as the
    delimiter unless the value itself contains a double quote; in that
    case single quotes are used, and if the value contains both kinds
    of quote, the double quotes inside it are escaped as &quot;.

    Further substitutions can be requested through the optional
    entities dictionary: every key found in the data is replaced by
    its value.  Keys and values must all be strings.
    """
    # Work on a copy so the caller's dictionary is left untouched.
    replacements = entities.copy()
    replacements['\n'] = '&#10;'
    replacements['\r'] = '&#13;'
    replacements['\t'] = '&#9;'
    escaped = escape(data, replacements)

    if '"' not in escaped:
        return '"%s"' % escaped
    if "'" not in escaped:
        return "'%s'" % escaped
    # Both kinds of quote occur: delimit with double quotes and escape
    # the double quotes inside the value.
    return '"%s"' % escaped.replace('"', "&quot;")
70
71
def _gettextwriter(out, encoding):
    """Return a text-mode writer for out.

    out      -- None, a text stream, a codecs stream writer, a raw
                binary stream, or any other object that has a write()
                method.
    encoding -- encoding handed to the TextIOWrapper that is created
                when out has to be wrapped.

    None selects sys.stdout.  Text streams and codecs stream writers
    are returned unchanged.  Anything else is wrapped in an
    io.TextIOWrapper that writes through immediately, uses '\n' as
    line ending and replaces unencodable characters with XML character
    references.
    """
    if out is None:
        import sys
        return sys.stdout

    if isinstance(out, io.TextIOBase):
        # use a text writer as is
        return out

    if isinstance(out, (codecs.StreamWriter, codecs.StreamReaderWriter)):
        # use a codecs stream writer as is
        return out

    # wrap a binary writer with TextIOWrapper
    if isinstance(out, io.RawIOBase):
        # Keep the original file open when the TextIOWrapper is
        # destroyed
        class _wrapper:
            # Proxy that reports out's class and forwards every
            # attribute lookup it cannot satisfy itself to out.
            __class__ = out.__class__
            def __getattr__(self, name):
                return getattr(out, name)
        buffer = _wrapper()
        # Closing the proxy is a no-op, so out itself is never closed
        # through it.
        buffer.close = lambda: None
    else:
        # This is to handle passed objects that aren't in the
        # IOBase hierarchy, but just have a write method
        buffer = io.BufferedIOBase()
        buffer.writable = lambda: True
        buffer.write = out.write
        try:
            # TextIOWrapper uses these methods to determine
            # if BOM (for UTF-16, etc) should be added
            buffer.seekable = out.seekable
            buffer.tell = out.tell
        except AttributeError:
            # out lacks seekable() and/or tell(); keep the defaults of
            # io.BufferedIOBase for whatever was not copied.
            pass
    return io.TextIOWrapper(buffer, encoding=encoding,
                            errors='xmlcharrefreplace',
                            newline='\n',
                            write_through=True)
112
class XMLGenerator(handler.ContentHandler):
    """Content handler that writes the SAX events it receives as XML text.

    Character data and attribute values are escaped on output.
    Namespace declarations announced through startPrefixMapping() are
    written as xmlns attributes on the next element started with
    startElementNS().
    """

    def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
        """Create a generator.

        out      -- destination; it is passed through _gettextwriter(),
                    which turns it into an object with write() and
                    flush() methods.
        encoding -- name written into the XML declaration, used for the
                    output writer, and used to decode content that is
                    not passed in as str.
        short_empty_elements -- if true, an element without content is
                    written as <name/> instead of <name></name>.
        """
        handler.ContentHandler.__init__(self)
        out = _gettextwriter(out, encoding)
        self._write = out.write
        self._flush = out.flush
        self._ns_contexts = [{}] # contains uri -> prefix dicts
        self._current_context = self._ns_contexts[-1]
        # (prefix, uri) pairs announced but not yet written as xmlns
        # attributes
        self._undeclared_ns_maps = []
        self._encoding = encoding
        self._short_empty_elements = short_empty_elements
        # True while a start tag has been written without its closing
        # '>' (only happens when short_empty_elements is enabled)
        self._pending_start_element = False

    def _qname(self, name):
        """Builds a qualified name from a (ns_url, localname) pair"""
        if name[0]:
            # Per http://www.w3.org/XML/1998/namespace, The 'xml' prefix is
            # bound by definition to http://www.w3.org/XML/1998/namespace.  It
            # does not need to be declared and will not usually be found in
            # self._current_context.
            if 'http://www.w3.org/XML/1998/namespace' == name[0]:
                return 'xml:' + name[1]
            # The name is in a non-empty namespace
            prefix = self._current_context[name[0]]
            if prefix:
                # If it is not the default namespace, prepend the prefix
                return prefix + ":" + name[1]
        # Return the unqualified name
        return name[1]

    def _finish_pending_start_element(self,endElement=False):
        """Write the '>' of a start tag that is still open, if any.

        The endElement parameter is not used; it is kept so that
        existing callers passing it keep working.
        """
        if self._pending_start_element:
            self._write('>')
            self._pending_start_element = False

    # ContentHandler methods

    def startDocument(self):
        """Write the XML declaration, naming the configured encoding."""
        self._write('<?xml version="1.0" encoding="%s"?>\n' %
                        self._encoding)

    def endDocument(self):
        """Flush the output writer."""
        self._flush()

    def startPrefixMapping(self, prefix, uri):
        """Record a namespace mapping and queue it for declaration.

        The current uri -> prefix context is saved so that
        endPrefixMapping() can restore it.
        """
        self._ns_contexts.append(self._current_context.copy())
        self._current_context[uri] = prefix
        self._undeclared_ns_maps.append((prefix, uri))

    def endPrefixMapping(self, prefix):
        """Restore the namespace context saved by the matching
        startPrefixMapping() call."""
        self._current_context = self._ns_contexts[-1]
        del self._ns_contexts[-1]

    def startElement(self, name, attrs):
        """Write the start tag of element name with its attributes.

        attrs must provide items() yielding (name, value) pairs; the
        values are escaped and quoted with quoteattr().
        """
        self._finish_pending_start_element()
        self._write('<' + name)
        for (attr_name, value) in attrs.items():
            self._write(' %s=%s' % (attr_name, quoteattr(value)))
        if self._short_empty_elements:
            # Leave the tag open: endElement() may turn it into <name/>.
            self._pending_start_element = True
        else:
            self._write(">")

    def endElement(self, name):
        """Write the end tag of element name, or close a still-open
        start tag as an empty-element tag."""
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % name)

    def startElementNS(self, name, qname, attrs):
        """Write the start tag of a namespaced element.

        name  -- (uri, localname) pair; the written name is derived
                 from it with _qname().
        qname -- not used.
        attrs -- must provide items() yielding ((uri, localname), value)
                 pairs.

        All namespace mappings announced since the previous start tag
        are declared on this element.
        """
        self._finish_pending_start_element()
        self._write('<' + self._qname(name))

        for prefix, uri in self._undeclared_ns_maps:
            # The URI is an attribute value like any other and must be
            # escaped and quoted; writing it verbatim yields malformed
            # XML when it contains characters such as '&', '<' or '"'.
            # A missing (None) URI is written as an empty value rather
            # than as the literal text "None".
            quoted_uri = quoteattr(uri or '')
            if prefix:
                self._write(' xmlns:%s=%s' % (prefix, quoted_uri))
            else:
                self._write(' xmlns=%s' % quoted_uri)
        self._undeclared_ns_maps = []

        for (attr_name, value) in attrs.items():
            self._write(' %s=%s' % (self._qname(attr_name), quoteattr(value)))
        if self._short_empty_elements:
            # Leave the tag open: endElementNS() may turn it into <name/>.
            self._pending_start_element = True
        else:
            self._write(">")

    def endElementNS(self, name, qname):
        """Write the end tag of a namespaced element, or close a
        still-open start tag as an empty-element tag.

        name is a (uri, localname) pair; qname is not used.
        """
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % self._qname(name))

    def characters(self, content):
        """Write character data, escaping &, < and >.

        Content that is not a str is decoded with the configured
        encoding first.  Empty content is ignored.
        """
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                content = str(content, self._encoding)
            self._write(escape(content))

    def ignorableWhitespace(self, content):
        """Write whitespace as is, without escaping.

        Content that is not a str is decoded with the configured
        encoding first.  Empty content is ignored.
        """
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                content = str(content, self._encoding)
            self._write(content)

    def processingInstruction(self, target, data):
        """Write a processing instruction with the given target and data."""
        self._finish_pending_start_element()
        self._write('<?%s %s?>' % (target, data))
226
227
class XMLFilterBase(xmlreader.XMLReader):
    """This class is designed to sit between an XMLReader and the
    client application's event handlers.  By default, it does nothing
    but pass requests up to the reader and events on to the handlers
    unmodified, but subclasses can override specific methods to modify
    the event stream or the configuration requests as they pass
    through."""

    def __init__(self, parent = None):
        """Create a filter on top of parent, the reader whose events
        are to be filtered.  The parent can also be supplied later
        through setParent()."""
        xmlreader.XMLReader.__init__(self)
        self._parent = parent

    # ErrorHandler methods
    #
    # Each event is forwarded unchanged to self._err_handler
    # (presumably set up by xmlreader.XMLReader; not visible here).

    def error(self, exception):
        self._err_handler.error(exception)

    def fatalError(self, exception):
        self._err_handler.fatalError(exception)

    def warning(self, exception):
        self._err_handler.warning(exception)

    # ContentHandler methods
    #
    # Each event is forwarded unchanged to self._cont_handler
    # (presumably set up by xmlreader.XMLReader; not visible here).

    def setDocumentLocator(self, locator):
        self._cont_handler.setDocumentLocator(locator)

    def startDocument(self):
        self._cont_handler.startDocument()

    def endDocument(self):
        self._cont_handler.endDocument()

    def startPrefixMapping(self, prefix, uri):
        self._cont_handler.startPrefixMapping(prefix, uri)

    def endPrefixMapping(self, prefix):
        self._cont_handler.endPrefixMapping(prefix)

    def startElement(self, name, attrs):
        self._cont_handler.startElement(name, attrs)

    def endElement(self, name):
        self._cont_handler.endElement(name)

    def startElementNS(self, name, qname, attrs):
        self._cont_handler.startElementNS(name, qname, attrs)

    def endElementNS(self, name, qname):
        self._cont_handler.endElementNS(name, qname)

    def characters(self, content):
        self._cont_handler.characters(content)

    def ignorableWhitespace(self, chars):
        self._cont_handler.ignorableWhitespace(chars)

    def processingInstruction(self, target, data):
        self._cont_handler.processingInstruction(target, data)

    def skippedEntity(self, name):
        self._cont_handler.skippedEntity(name)

    # DTDHandler methods
    #
    # Each event is forwarded unchanged to self._dtd_handler
    # (presumably set up by xmlreader.XMLReader; not visible here).

    def notationDecl(self, name, publicId, systemId):
        self._dtd_handler.notationDecl(name, publicId, systemId)

    def unparsedEntityDecl(self, name, publicId, systemId, ndata):
        self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)

    # EntityResolver methods
    #
    # The request is forwarded to self._ent_handler (presumably set up
    # by xmlreader.XMLReader; not visible here) and its answer is
    # returned to the caller.

    def resolveEntity(self, publicId, systemId):
        return self._ent_handler.resolveEntity(publicId, systemId)

    # XMLReader methods

    def parse(self, source):
        """Install this filter as every handler of the parent reader,
        then let the parent parse source.  A parent must have been
        set before this is called."""
        self._parent.setContentHandler(self)
        self._parent.setErrorHandler(self)
        self._parent.setEntityResolver(self)
        self._parent.setDTDHandler(self)
        self._parent.parse(source)

    # The configuration requests below are passed on to the parent
    # reader unchanged.

    def setLocale(self, locale):
        self._parent.setLocale(locale)

    def getFeature(self, name):
        return self._parent.getFeature(name)

    def setFeature(self, name, state):
        self._parent.setFeature(name, state)

    def getProperty(self, name):
        return self._parent.getProperty(name)

    def setProperty(self, name, value):
        self._parent.setProperty(name, value)

    # XMLFilter methods

    def getParent(self):
        """Return the parent reader (None if none has been set)."""
        return self._parent

    def setParent(self, parent):
        """Set the reader this filter takes its events from."""
        self._parent = parent
336
337# --- Utility functions
338
def prepare_input_source(source, base=""):
    """This function takes an InputSource and an optional base URL and
    returns a fully resolved InputSource object ready for reading.

    source -- an InputSource, a system identifier given as a string or
              as a path-like object (for example a pathlib.Path), or an
              object with a read() method.
    base   -- base URL or path against which a relative system
              identifier is resolved.

    If the resulting InputSource has neither a character stream nor a
    byte stream, its system identifier is opened: as a local file (in
    binary mode) when it names an existing file relative to the
    directory of base, otherwise as a URL joined to base.  The opened
    stream is attached as the byte stream; closing it is up to the
    caller.
    """

    # Path-like objects are neither str nor readable; without this
    # conversion they would fall through both branches below and fail
    # with an AttributeError.
    if isinstance(source, os.PathLike):
        source = os.fspath(source)

    if isinstance(source, str):
        source = xmlreader.InputSource(source)
    elif hasattr(source, "read"):
        f = source
        source = xmlreader.InputSource()
        # A zero-length read returns an empty str or bytes object and
        # so reveals whether the stream is in text or binary mode.
        if isinstance(f.read(0), str):
            source.setCharacterStream(f)
        else:
            source.setByteStream(f)
        if hasattr(f, "name") and isinstance(f.name, str):
            source.setSystemId(f.name)

    if source.getCharacterStream() is None and source.getByteStream() is None:
        sysid = source.getSystemId()
        basehead = os.path.dirname(os.path.normpath(base))
        sysidfilename = os.path.join(basehead, sysid)
        if os.path.isfile(sysidfilename):
            source.setSystemId(sysidfilename)
            f = open(sysidfilename, "rb")
        else:
            source.setSystemId(urllib.parse.urljoin(base, sysid))
            f = urllib.request.urlopen(source.getSystemId())

        source.setByteStream(f)

    return source
369