1"""\
2A library of useful helper classes to the SAX classes, for the
3convenience of application and driver writers.
4"""
5
6import os, urlparse, urllib, types
7import io
8import sys
9import handler
10import xmlreader
11
12try:
13    _StringTypes = [types.StringType, types.UnicodeType]
14except AttributeError:
15    _StringTypes = [types.StringType]
16
17def __dict_replace(s, d):
18    """Replace substrings of a string using a dictionary."""
19    for key, value in d.items():
20        s = s.replace(key, value)
21    return s
22
23def escape(data, entities={}):
24    """Escape &, <, and > in a string of data.
25
26    You can escape other strings of data by passing a dictionary as
27    the optional entities parameter.  The keys and values must all be
28    strings; each key will be replaced with its corresponding value.
29    """
30
31    # must do ampersand first
32    data = data.replace("&", "&amp;")
33    data = data.replace(">", "&gt;")
34    data = data.replace("<", "&lt;")
35    if entities:
36        data = __dict_replace(data, entities)
37    return data
38
39def unescape(data, entities={}):
40    """Unescape &amp;, &lt;, and &gt; in a string of data.
41
42    You can unescape other strings of data by passing a dictionary as
43    the optional entities parameter.  The keys and values must all be
44    strings; each key will be replaced with its corresponding value.
45    """
46    data = data.replace("&lt;", "<")
47    data = data.replace("&gt;", ">")
48    if entities:
49        data = __dict_replace(data, entities)
50    # must do ampersand last
51    return data.replace("&amp;", "&")
52
53def quoteattr(data, entities={}):
54    """Escape and quote an attribute value.
55
56    Escape &, <, and > in a string of data, then quote it for use as
57    an attribute value.  The \" character will be escaped as well, if
58    necessary.
59
60    You can escape other strings of data by passing a dictionary as
61    the optional entities parameter.  The keys and values must all be
62    strings; each key will be replaced with its corresponding value.
63    """
64    entities = entities.copy()
65    entities.update({'\n': '&#10;', '\r': '&#13;', '\t':'&#9;'})
66    data = escape(data, entities)
67    if '"' in data:
68        if "'" in data:
69            data = '"%s"' % data.replace('"', "&quot;")
70        else:
71            data = "'%s'" % data
72    else:
73        data = '"%s"' % data
74    return data
75
76
77def _gettextwriter(out, encoding):
78    if out is None:
79        import sys
80        out = sys.stdout
81
82    if isinstance(out, io.RawIOBase):
83        buffer = io.BufferedIOBase(out)
84        # Keep the original file open when the TextIOWrapper is
85        # destroyed
86        buffer.close = lambda: None
87    else:
88        # This is to handle passed objects that aren't in the
89        # IOBase hierarchy, but just have a write method
90        buffer = io.BufferedIOBase()
91        buffer.writable = lambda: True
92        buffer.write = out.write
93        try:
94            # TextIOWrapper uses this methods to determine
95            # if BOM (for UTF-16, etc) should be added
96            buffer.seekable = out.seekable
97            buffer.tell = out.tell
98        except AttributeError:
99            pass
100    # wrap a binary writer with TextIOWrapper
101    class UnbufferedTextIOWrapper(io.TextIOWrapper):
102        def write(self, s):
103            super(UnbufferedTextIOWrapper, self).write(s)
104            self.flush()
105    return UnbufferedTextIOWrapper(buffer, encoding=encoding,
106                                   errors='xmlcharrefreplace',
107                                   newline='\n')
108
109class XMLGenerator(handler.ContentHandler):
110
111    def __init__(self, out=None, encoding="iso-8859-1"):
112        handler.ContentHandler.__init__(self)
113        out = _gettextwriter(out, encoding)
114        self._write = out.write
115        self._flush = out.flush
116        self._ns_contexts = [{}] # contains uri -> prefix dicts
117        self._current_context = self._ns_contexts[-1]
118        self._undeclared_ns_maps = []
119        self._encoding = encoding
120
121    def _qname(self, name):
122        """Builds a qualified name from a (ns_url, localname) pair"""
123        if name[0]:
124            # Per http://www.w3.org/XML/1998/namespace, The 'xml' prefix is
125            # bound by definition to http://www.w3.org/XML/1998/namespace.  It
126            # does not need to be declared and will not usually be found in
127            # self._current_context.
128            if 'http://www.w3.org/XML/1998/namespace' == name[0]:
129                return 'xml:' + name[1]
130            # The name is in a non-empty namespace
131            prefix = self._current_context[name[0]]
132            if prefix:
133                # If it is not the default namespace, prepend the prefix
134                return prefix + ":" + name[1]
135        # Return the unqualified name
136        return name[1]
137
138    # ContentHandler methods
139
140    def startDocument(self):
141        self._write(u'<?xml version="1.0" encoding="%s"?>\n' %
142                        self._encoding)
143
144    def endDocument(self):
145        self._flush()
146
147    def startPrefixMapping(self, prefix, uri):
148        self._ns_contexts.append(self._current_context.copy())
149        self._current_context[uri] = prefix
150        self._undeclared_ns_maps.append((prefix, uri))
151
152    def endPrefixMapping(self, prefix):
153        self._current_context = self._ns_contexts[-1]
154        del self._ns_contexts[-1]
155
156    def startElement(self, name, attrs):
157        self._write(u'<' + name)
158        for (name, value) in attrs.items():
159            self._write(u' %s=%s' % (name, quoteattr(value)))
160        self._write(u'>')
161
162    def endElement(self, name):
163        self._write(u'</%s>' % name)
164
165    def startElementNS(self, name, qname, attrs):
166        self._write(u'<' + self._qname(name))
167
168        for prefix, uri in self._undeclared_ns_maps:
169            if prefix:
170                self._write(u' xmlns:%s="%s"' % (prefix, uri))
171            else:
172                self._write(u' xmlns="%s"' % uri)
173        self._undeclared_ns_maps = []
174
175        for (name, value) in attrs.items():
176            self._write(u' %s=%s' % (self._qname(name), quoteattr(value)))
177        self._write(u'>')
178
179    def endElementNS(self, name, qname):
180        self._write(u'</%s>' % self._qname(name))
181
182    def characters(self, content):
183        self._write(escape(unicode(content)))
184
185    def ignorableWhitespace(self, content):
186        self._write(unicode(content))
187
188    def processingInstruction(self, target, data):
189        self._write(u'<?%s %s?>' % (target, data))
190
191
192class XMLFilterBase(xmlreader.XMLReader):
193    """This class is designed to sit between an XMLReader and the
194    client application's event handlers.  By default, it does nothing
195    but pass requests up to the reader and events on to the handlers
196    unmodified, but subclasses can override specific methods to modify
197    the event stream or the configuration requests as they pass
198    through."""
199
200    def __init__(self, parent = None):
201        xmlreader.XMLReader.__init__(self)
202        self._parent = parent
203
204    # ErrorHandler methods
205
206    def error(self, exception):
207        self._err_handler.error(exception)
208
209    def fatalError(self, exception):
210        self._err_handler.fatalError(exception)
211
212    def warning(self, exception):
213        self._err_handler.warning(exception)
214
215    # ContentHandler methods
216
217    def setDocumentLocator(self, locator):
218        self._cont_handler.setDocumentLocator(locator)
219
220    def startDocument(self):
221        self._cont_handler.startDocument()
222
223    def endDocument(self):
224        self._cont_handler.endDocument()
225
226    def startPrefixMapping(self, prefix, uri):
227        self._cont_handler.startPrefixMapping(prefix, uri)
228
229    def endPrefixMapping(self, prefix):
230        self._cont_handler.endPrefixMapping(prefix)
231
232    def startElement(self, name, attrs):
233        self._cont_handler.startElement(name, attrs)
234
235    def endElement(self, name):
236        self._cont_handler.endElement(name)
237
238    def startElementNS(self, name, qname, attrs):
239        self._cont_handler.startElementNS(name, qname, attrs)
240
241    def endElementNS(self, name, qname):
242        self._cont_handler.endElementNS(name, qname)
243
244    def characters(self, content):
245        self._cont_handler.characters(content)
246
247    def ignorableWhitespace(self, chars):
248        self._cont_handler.ignorableWhitespace(chars)
249
250    def processingInstruction(self, target, data):
251        self._cont_handler.processingInstruction(target, data)
252
253    def skippedEntity(self, name):
254        self._cont_handler.skippedEntity(name)
255
256    # DTDHandler methods
257
258    def notationDecl(self, name, publicId, systemId):
259        self._dtd_handler.notationDecl(name, publicId, systemId)
260
261    def unparsedEntityDecl(self, name, publicId, systemId, ndata):
262        self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)
263
264    # EntityResolver methods
265
266    def resolveEntity(self, publicId, systemId):
267        return self._ent_handler.resolveEntity(publicId, systemId)
268
269    # XMLReader methods
270
271    def parse(self, source):
272        self._parent.setContentHandler(self)
273        self._parent.setErrorHandler(self)
274        self._parent.setEntityResolver(self)
275        self._parent.setDTDHandler(self)
276        self._parent.parse(source)
277
278    def setLocale(self, locale):
279        self._parent.setLocale(locale)
280
281    def getFeature(self, name):
282        return self._parent.getFeature(name)
283
284    def setFeature(self, name, state):
285        self._parent.setFeature(name, state)
286
287    def getProperty(self, name):
288        return self._parent.getProperty(name)
289
290    def setProperty(self, name, value):
291        self._parent.setProperty(name, value)
292
293    # XMLFilter methods
294
295    def getParent(self):
296        return self._parent
297
298    def setParent(self, parent):
299        self._parent = parent
300
301# --- Utility functions
302
303def prepare_input_source(source, base = ""):
304    """This function takes an InputSource and an optional base URL and
305    returns a fully resolved InputSource object ready for reading."""
306
307    if type(source) in _StringTypes:
308        source = xmlreader.InputSource(source)
309    elif hasattr(source, "read"):
310        f = source
311        source = xmlreader.InputSource()
312        source.setByteStream(f)
313        if hasattr(f, "name"):
314            source.setSystemId(f.name)
315
316    if source.getByteStream() is None:
317        try:
318            sysid = source.getSystemId()
319            basehead = os.path.dirname(os.path.normpath(base))
320            encoding = sys.getfilesystemencoding()
321            if isinstance(sysid, unicode):
322                if not isinstance(basehead, unicode):
323                    try:
324                        basehead = basehead.decode(encoding)
325                    except UnicodeDecodeError:
326                        sysid = sysid.encode(encoding)
327            else:
328                if isinstance(basehead, unicode):
329                    try:
330                        sysid = sysid.decode(encoding)
331                    except UnicodeDecodeError:
332                        basehead = basehead.encode(encoding)
333            sysidfilename = os.path.join(basehead, sysid)
334            isfile = os.path.isfile(sysidfilename)
335        except UnicodeError:
336            isfile = False
337        if isfile:
338            source.setSystemId(sysidfilename)
339            f = open(sysidfilename, "rb")
340        else:
341            source.setSystemId(urlparse.urljoin(base, source.getSystemId()))
342            f = urllib.urlopen(source.getSystemId())
343
344        source.setByteStream(f)
345
346    return source
347