saxutils.py revision 93f3910dca86a18b52e39df10a3a58a9e7819af2
1"""\
2A library of useful helper classes to the SAX classes, for the
3convenience of application and driver writers.
4"""
5
6import os, urlparse, urllib, types
7from . import handler
8from . import xmlreader
9
# String types that prepare_input_source() treats as a system identifier.
# On Python 2 these are the byte-string type plus, when present, the unicode
# type; where types.StringType does not exist the builtin str is used alone.
_StringTypes = [getattr(types, "StringType", str)]
if hasattr(types, "StringType") and hasattr(types, "UnicodeType"):
    _StringTypes.append(types.UnicodeType)
17
18# See whether the xmlcharrefreplace error handler is
19# supported
20try:
21    from codecs import xmlcharrefreplace_errors
22    _error_handling = "xmlcharrefreplace"
23    del xmlcharrefreplace_errors
24except ImportError:
25    _error_handling = "strict"
26
def __dict_replace(s, d):
    """Apply every ``old -> new`` pair of mapping *d* to string *s*.

    The pairs are applied one after another in ``d.items()`` order, so text
    produced by an earlier replacement can be matched by a later one.
    """
    result = s
    for old, new in d.items():
        result = result.replace(old, new)
    return result
32
def escape(data, entities={}):
    """Return *data* with '&', '<' and '>' replaced by entity references.

    *entities* is an optional mapping of further substrings to escape; each
    key is replaced by its value after the three standard characters have
    been handled.  Keys and values must all be strings.
    """
    # '&' must be handled first; otherwise the '&' inside the '&gt;' and
    # '&lt;' produced by the later replacements would be escaped again.
    for char, ref in (("&", "&amp;"), (">", "&gt;"), ("<", "&lt;")):
        data = data.replace(char, ref)
    if not entities:
        return data
    return __dict_replace(data, entities)
48
def unescape(data, entities={}):
    """Return *data* with '&lt;', '&gt;' and '&amp;' turned back into characters.

    *entities* is an optional mapping of further substrings to unescape;
    each key is replaced by its value before '&amp;' is decoded.  Keys and
    values must all be strings.
    """
    for ref, char in (("&lt;", "<"), ("&gt;", ">")):
        data = data.replace(ref, char)
    if entities:
        data = __dict_replace(data, entities)
    # '&amp;' is decoded last so that an escaped reference such as
    # '&amp;lt;' becomes the literal text '&lt;' rather than '<'.
    return data.replace("&amp;", "&")
62
def quoteattr(data, entities={}):
    """Escape *data* and wrap it in quotes for use as an attribute value.

    '&', '<' and '>' are escaped as by escape(); newline, carriage return
    and tab are additionally replaced by the character references '&#10;',
    '&#13;' and '&#9;' (presumably so a parser does not normalise them to
    spaces).  *entities* is an optional mapping of further substrings to
    escape; it is copied, not modified.

    The value is wrapped in double quotes unless it contains '"' but no
    "'", in which case single quotes are used.  If it contains both, the
    double quotes inside are replaced by '&quot;'.
    """
    replacements = entities.copy()
    replacements.update({'\n': '&#10;', '\r': '&#13;', '\t': '&#9;'})
    escaped = escape(data, replacements)
    if '"' not in escaped:
        return '"%s"' % escaped
    if "'" not in escaped:
        return "'%s'" % escaped
    return '"%s"' % escaped.replace('"', "&quot;")
85
86
class XMLGenerator(handler.ContentHandler):
    """ContentHandler that writes the SAX events it receives back out as XML.

    out      -- file-like object receiving the output (sys.stdout if None).
    encoding -- encoding named in the XML declaration and used to encode
                non-``str`` (unicode) text before it is written.
    """

    def __init__(self, out=None, encoding="iso-8859-1"):
        if out is None:
            import sys
            out = sys.stdout
        handler.ContentHandler.__init__(self)
        self._out = out
        # Stack of saved namespace contexts; the current context is the live
        # one.  Each context maps a namespace URI to its bound prefix.
        self._ns_contexts = [{}] # contains uri -> prefix dicts
        self._current_context = self._ns_contexts[-1]
        # (prefix, uri) pairs reported by startPrefixMapping that still have
        # to be emitted as xmlns attributes on the next start tag.
        self._undeclared_ns_maps = []
        self._encoding = encoding

    def _write(self, text):
        """Write *text* to the output stream.

        ``str`` objects are written unchanged; anything else (unicode) is
        first encoded with the generator's encoding, using the module-level
        ``_error_handling`` scheme for characters it cannot represent.
        """
        if isinstance(text, str):
            self._out.write(text)
        else:
            self._out.write(text.encode(self._encoding, _error_handling))

    def _qname(self, name):
        """Builds a qualified name from a (ns_url, localname) pair"""
        if name[0]:
            # The 'xml' prefix is bound by definition to this namespace URI
            # and is normally not reported through startPrefixMapping, so it
            # is absent from the context; looking it up would raise KeyError
            # (e.g. for an xml:lang attribute).
            if name[0] == 'http://www.w3.org/XML/1998/namespace':
                return 'xml:' + name[1]
            # The name is in a non-empty namespace
            prefix = self._current_context[name[0]]
            if prefix:
                # If it is not the default namespace, prepend the prefix
                return prefix + ":" + name[1]
        # Return the unqualified name
        return name[1]

    # ContentHandler methods

    def startDocument(self):
        """Emit the XML declaration naming the output encoding."""
        self._write('<?xml version="1.0" encoding="%s"?>\n' %
                        self._encoding)

    def startPrefixMapping(self, prefix, uri):
        """Bind *prefix* to *uri* and queue its xmlns declaration."""
        # Save a copy of the current context so endPrefixMapping can restore
        # it, then record the new binding in the live current context.
        self._ns_contexts.append(self._current_context.copy())
        self._current_context[uri] = prefix
        self._undeclared_ns_maps.append((prefix, uri))

    def endPrefixMapping(self, prefix):
        """Restore the namespace context saved by the latest start call."""
        # *prefix* is not used: each call pops the most recently saved
        # context, so the order of the end calls for one element is
        # irrelevant as long as every start has a matching end.
        self._current_context = self._ns_contexts[-1]
        del self._ns_contexts[-1]

    def startElement(self, name, attrs):
        """Write a start tag for *name* with quoted, escaped attributes."""
        self._write('<' + name)
        for (attr_name, value) in attrs.items():
            self._write(' %s=%s' % (attr_name, quoteattr(value)))
        self._write('>')

    def endElement(self, name):
        """Write the end tag for *name*."""
        self._write('</%s>' % name)

    def startElementNS(self, name, qname, attrs):
        """Write a namespaced start tag, including pending xmlns declarations.

        *name* and each attribute name are (uri, localname) pairs; *qname*
        is not used.
        """
        self._write('<' + self._qname(name))

        # Declarations go through _write (so unicode URIs are encoded like
        # all other output) and quoteattr (so '&', '<' or quotes in a URI
        # cannot produce malformed XML).  For ordinary URIs the output is
        # unchanged: quoteattr wraps them in double quotes.
        for prefix, uri in self._undeclared_ns_maps:
            if prefix:
                self._write(' xmlns:%s=%s' % (prefix, quoteattr(uri)))
            else:
                self._write(' xmlns=%s' % quoteattr(uri))
        self._undeclared_ns_maps = []

        for (attr_name, value) in attrs.items():
            self._write(' %s=%s' % (self._qname(attr_name), quoteattr(value)))
        self._write('>')

    def endElementNS(self, name, qname):
        """Write the end tag for the (uri, localname) pair *name*."""
        self._write('</%s>' % self._qname(name))

    def characters(self, content):
        """Write character data with '&', '<' and '>' escaped."""
        self._write(escape(content))

    def ignorableWhitespace(self, content):
        """Write whitespace verbatim (no escaping)."""
        self._write(content)

    def processingInstruction(self, target, data):
        """Write a processing instruction ``<?target data?>`` verbatim."""
        self._write('<?%s %s?>' % (target, data))
166
167
class XMLFilterBase(xmlreader.XMLReader):
    """Pass-through filter placed between an XMLReader and the application.

    Configuration requests (parse, locale, features, properties) are
    forwarded to the parent reader, and every event the parent reports is
    forwarded unchanged to the handlers registered on this filter.  Override
    individual methods in a subclass to alter the stream or the requests.

    The handler attributes used below (_cont_handler, _err_handler,
    _dtd_handler, _ent_handler) are presumably initialised by
    xmlreader.XMLReader.__init__ -- confirm there.
    """

    def __init__(self, parent=None):
        xmlreader.XMLReader.__init__(self)
        # The reader whose events this filter relays; may be set later
        # through setParent().
        self._parent = parent

    # --- ErrorHandler interface: relay to this filter's error handler ---

    def error(self, exception):
        """Relay *exception* to the error handler's error()."""
        self._err_handler.error(exception)

    def fatalError(self, exception):
        """Relay *exception* to the error handler's fatalError()."""
        self._err_handler.fatalError(exception)

    def warning(self, exception):
        """Relay *exception* to the error handler's warning()."""
        self._err_handler.warning(exception)

    # --- ContentHandler interface: relay to this filter's content handler ---

    def setDocumentLocator(self, locator):
        """Relay the document locator."""
        self._cont_handler.setDocumentLocator(locator)

    def startDocument(self):
        """Relay the start-of-document event."""
        self._cont_handler.startDocument()

    def endDocument(self):
        """Relay the end-of-document event."""
        self._cont_handler.endDocument()

    def startPrefixMapping(self, prefix, uri):
        """Relay the start of a prefix mapping."""
        self._cont_handler.startPrefixMapping(prefix, uri)

    def endPrefixMapping(self, prefix):
        """Relay the end of a prefix mapping."""
        self._cont_handler.endPrefixMapping(prefix)

    def startElement(self, name, attrs):
        """Relay a (non-namespaced) start tag."""
        self._cont_handler.startElement(name, attrs)

    def endElement(self, name):
        """Relay a (non-namespaced) end tag."""
        self._cont_handler.endElement(name)

    def startElementNS(self, name, qname, attrs):
        """Relay a namespaced start tag."""
        self._cont_handler.startElementNS(name, qname, attrs)

    def endElementNS(self, name, qname):
        """Relay a namespaced end tag."""
        self._cont_handler.endElementNS(name, qname)

    def characters(self, content):
        """Relay character data."""
        self._cont_handler.characters(content)

    def ignorableWhitespace(self, chars):
        """Relay ignorable whitespace."""
        self._cont_handler.ignorableWhitespace(chars)

    def processingInstruction(self, target, data):
        """Relay a processing instruction."""
        self._cont_handler.processingInstruction(target, data)

    def skippedEntity(self, name):
        """Relay a skipped entity."""
        self._cont_handler.skippedEntity(name)

    # --- DTDHandler interface: relay to this filter's DTD handler ---

    def notationDecl(self, name, publicId, systemId):
        """Relay a notation declaration."""
        self._dtd_handler.notationDecl(name, publicId, systemId)

    def unparsedEntityDecl(self, name, publicId, systemId, ndata):
        """Relay an unparsed entity declaration."""
        self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)

    # --- EntityResolver interface: delegate and pass the answer back ---

    def resolveEntity(self, publicId, systemId):
        """Return whatever the entity resolver returns for this entity."""
        return self._ent_handler.resolveEntity(publicId, systemId)

    # --- XMLReader interface: forward requests to the parent reader ---

    def parse(self, source):
        """Install this filter as every handler of the parent, then parse.

        The parent's content, error, entity-resolver and DTD handlers are
        all set to this filter (in that order) before the parent is asked
        to parse *source*.
        """
        parent = self._parent
        parent.setContentHandler(self)
        parent.setErrorHandler(self)
        parent.setEntityResolver(self)
        parent.setDTDHandler(self)
        parent.parse(source)

    def setLocale(self, locale):
        """Forward the locale request to the parent."""
        self._parent.setLocale(locale)

    def getFeature(self, name):
        """Return the parent's value for feature *name*."""
        return self._parent.getFeature(name)

    def setFeature(self, name, state):
        """Set feature *name* on the parent."""
        self._parent.setFeature(name, state)

    def getProperty(self, name):
        """Return the parent's value for property *name*."""
        return self._parent.getProperty(name)

    def setProperty(self, name, value):
        """Set property *name* on the parent."""
        self._parent.setProperty(name, value)

    # --- XMLFilter interface ---

    def getParent(self):
        """Return the reader this filter relays for."""
        return self._parent

    def setParent(self, parent):
        """Replace the reader this filter relays for."""
        self._parent = parent
276
277# --- Utility functions
278
def prepare_input_source(source, base=""):
    """Return a fully resolved InputSource, ready for reading.

    *source* may be a system identifier (any of the _StringTypes, including
    subclasses), a file-like object with a ``read`` method, or an existing
    InputSource.  *base* is an optional base path/URL used to resolve a
    relative system identifier.

    If the resulting InputSource has no byte stream yet, its system id is
    first tried as a file path relative to the directory of *base* (opened
    in binary mode).  Otherwise it is joined onto *base* as a URL and opened
    with urllib.urlopen.  The opened stream is attached to the returned
    InputSource; this function does not close it.
    """
    string_types = tuple(_StringTypes)
    # isinstance (not an exact type match) so that subclasses of the string
    # types are accepted too; they would otherwise skip both branches and
    # fail at source.getByteStream() below.
    if isinstance(source, string_types):
        source = xmlreader.InputSource(source)
    elif hasattr(source, "read"):
        f = source
        source = xmlreader.InputSource()
        source.setByteStream(f)
        # Only a string name is usable as a system id; streams opened from
        # a file descriptor (e.g. io.open(fd)) report an integer name.
        name = getattr(f, "name", None)
        if isinstance(name, string_types):
            source.setSystemId(name)

    if source.getByteStream() is None:
        # NOTE(review): if the source has neither a byte stream nor a system
        # id, sysid is None and os.path.join below raises.
        sysid = source.getSystemId()
        basehead = os.path.dirname(os.path.normpath(base))
        sysidfilename = os.path.join(basehead, sysid)
        if os.path.isfile(sysidfilename):
            source.setSystemId(sysidfilename)
            f = open(sysidfilename, "rb")
        else:
            # Not a local file: resolve it as a URL relative to *base*.
            source.setSystemId(urlparse.urljoin(base, sysid))
            f = urllib.urlopen(source.getSystemId())

        source.setByteStream(f)

    return source
306