__all__ = [
    'LXMLTreeBuilderForXML',
    'LXMLTreeBuilder',
    ]

from io import BytesIO
from StringIO import StringIO
import collections
from lxml import etree
from bs4.element import Comment, Doctype, NamespacedAttribute
from bs4.builder import (
    FAST,
    HTML,
    HTMLTreeBuilder,
    PERMISSIVE,
    ParserRejectedMarkup,
    TreeBuilder,
    XML)
from bs4.dammit import EncodingDetector

LXML = 'lxml'

class LXMLTreeBuilderForXML(TreeBuilder):
    DEFAULT_PARSER_CLASS = etree.XMLParser

    is_xml = True

    # Well, it's permissive by XML parser standards.
    features = [LXML, XML, FAST, PERMISSIVE]

    CHUNK_SIZE = 512

    # This namespace mapping is specified in the XML Namespace
    # standard.
    DEFAULT_NSMAPS = {'http://www.w3.org/XML/1998/namespace' : "xml"}
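    # Thanks to this default entry, an attribute such as xml:lang (which
    # lxml reports as '{http://www.w3.org/XML/1998/namespace}lang') gets
    # the conventional 'xml' prefix even in documents that never declare it.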

    def default_parser(self, encoding):
        # This can either return a parser object or a class, which
        # will be instantiated with default arguments.
        if self._default_parser is not None:
            return self._default_parser
        return etree.XMLParser(
            target=self, strip_cdata=False, recover=True, encoding=encoding)

    def parser_for(self, encoding):
        # Use the default parser.
        parser = self.default_parser(encoding)

        if isinstance(parser, collections.Callable):
            # Instantiate the parser with default arguments
            parser = parser(target=self, strip_cdata=False, encoding=encoding)
        return parser
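
    # A rough sketch of the two cases, using the stock lxml parser class:
    #
    #   builder = LXMLTreeBuilderForXML(parser=etree.XMLParser)
    #   builder.parser_for('utf-8')
    #   # The class is callable, so it is instantiated here with
    #   # target=builder, strip_cdata=False and the requested encoding.
    #
    #   prebuilt = etree.XMLParser(target=builder, strip_cdata=False)
    #   LXMLTreeBuilderForXML(parser=prebuilt).parser_for('utf-8')
    #   # A parser instance is normally not callable, so it comes back
    #   # unchanged.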

    def __init__(self, parser=None, empty_element_tags=None):
        # TODO: Issue a warning if parser is present but not a
        # callable, since that means there's no way to create new
        # parsers for different encodings.
        self._default_parser = parser
        if empty_element_tags is not None:
            self.empty_element_tags = set(empty_element_tags)
        self.soup = None
        self.nsmaps = [self.DEFAULT_NSMAPS]

    def _getNsTag(self, tag):
        # Split the namespace URL out of a fully-qualified lxml tag
        # name. Copied from lxml's src/lxml/sax.py.
        if tag[0] == '{':
            return tuple(tag[1:].split('}', 1))
        else:
            return (None, tag)
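
    # lxml hands names over in Clark notation, so for example:
    #
    #   self._getNsTag('{http://www.w3.org/1999/xhtml}body')
    #   # => ('http://www.w3.org/1999/xhtml', 'body')
    #
    #   self._getNsTag('body')
    #   # => (None, 'body')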

    def prepare_markup(self, markup, user_specified_encoding=None,
                       document_declared_encoding=None):
        """
        :yield: A series of 4-tuples.
         (markup, encoding, declared encoding,
          has undergone character replacement)

        Each 4-tuple represents a strategy for parsing the document.
        """
        if isinstance(markup, unicode):
            # We were given Unicode. Maybe lxml can parse Unicode on
            # this system?
            yield markup, None, document_declared_encoding, False

        if isinstance(markup, unicode):
            # No, apparently not. Convert the Unicode to UTF-8 and
            # tell lxml to parse it as UTF-8.
            yield (markup.encode("utf8"), "utf8",
                   document_declared_encoding, False)

        # Instead of using UnicodeDammit to convert the bytestring to
        # Unicode using different encodings, use EncodingDetector to
        # iterate over the encodings, and tell lxml to try to parse
        # the document as each one in turn.
        is_html = not self.is_xml
        try_encodings = [user_specified_encoding, document_declared_encoding]
        detector = EncodingDetector(markup, try_encodings, is_html)
        for encoding in detector.encodings:
            yield (detector.markup, encoding, document_declared_encoding, False)
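
    # For example, a Unicode document yields two strategies up front (parse
    # it as-is, then re-encoded as UTF-8), while a bytestring yields one
    # strategy per candidate encoding, roughly:
    #
    #   (markup, 'iso-8859-1', None, False)   # the user-specified encoding
    #   (markup, 'utf-8', None, False)        # a fallback from EncodingDetector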

    def feed(self, markup):
        if isinstance(markup, bytes):
            markup = BytesIO(markup)
        elif isinstance(markup, unicode):
            markup = StringIO(markup)

        # Call feed() at least once, even if the markup is empty,
        # or the parser won't be initialized.
        data = markup.read(self.CHUNK_SIZE)
        try:
            self.parser = self.parser_for(self.soup.original_encoding)
            self.parser.feed(data)
            while len(data) != 0:
                # Now call feed() on the rest of the data, chunk by chunk.
                data = markup.read(self.CHUNK_SIZE)
                if len(data) != 0:
                    self.parser.feed(data)
            self.parser.close()
        except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
            raise ParserRejectedMarkup(str(e))
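
    # Because the parser is created with target=self, lxml does not build a
    # tree of its own; it drives the handler methods below. Feeding markup
    # like
    #
    #   <root xmlns:ex="http://example.com/"><ex:item>text</ex:item></root>
    #
    # results in roughly this sequence of calls:
    #
    #   start('root', {}, nsmap={'ex': 'http://example.com/'})
    #   start('{http://example.com/}item', {}, nsmap={})
    #   data('text')
    #   end('{http://example.com/}item')
    #   end('root')
    #   close()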

    def close(self):
        self.nsmaps = [self.DEFAULT_NSMAPS]

    def start(self, name, attrs, nsmap={}):
        # Make sure attrs is a mutable dict--lxml may send an immutable dictproxy.
        attrs = dict(attrs)
        nsprefix = None
        # Invert each namespace map as it comes in.
        if len(nsmap) == 0 and len(self.nsmaps) > 1:
            # There are no new namespaces for this tag, but
            # non-default namespaces are in play, so we need a
            # separate tag stack to know when they end.
            self.nsmaps.append(None)
        elif len(nsmap) > 0:
            # A new namespace mapping has come into play.
            inverted_nsmap = dict((value, key) for key, value in nsmap.items())
            self.nsmaps.append(inverted_nsmap)
            # Also treat the namespace mapping as a set of attributes on the
            # tag, so we can recreate it later.
            attrs = attrs.copy()
            for prefix, namespace in nsmap.items():
                attribute = NamespacedAttribute(
                    "xmlns", prefix, "http://www.w3.org/2000/xmlns/")
                attrs[attribute] = namespace

        # Namespaces are in play. Find any attributes that came in
        # from lxml with namespaces attached to their names, and
        # turn them into NamespacedAttribute objects.
        new_attrs = {}
        for attr, value in attrs.items():
            namespace, attr = self._getNsTag(attr)
            if namespace is None:
                new_attrs[attr] = value
            else:
                nsprefix = self._prefix_for_namespace(namespace)
                attr = NamespacedAttribute(nsprefix, attr, namespace)
                new_attrs[attr] = value
        attrs = new_attrs

        namespace, name = self._getNsTag(name)
        nsprefix = self._prefix_for_namespace(namespace)
        self.soup.handle_starttag(name, namespace, nsprefix, attrs)
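
    # For example, a start event for markup like
    #
    #   <svg xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#target"/>
    #
    # reaches handle_starttag with an attrs dict roughly like:
    #
    #   {NamespacedAttribute('xmlns', 'xlink', 'http://www.w3.org/2000/xmlns/'):
    #        'http://www.w3.org/1999/xlink',
    #    NamespacedAttribute('xlink', 'href', 'http://www.w3.org/1999/xlink'):
    #        '#target'}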

    def _prefix_for_namespace(self, namespace):
        """Find the currently active prefix for the given namespace."""
        if namespace is None:
            return None
        for inverted_nsmap in reversed(self.nsmaps):
            if inverted_nsmap is not None and namespace in inverted_nsmap:
                return inverted_nsmap[namespace]
        return None
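
    # For example, with self.nsmaps looking roughly like
    #
    #   [{'http://www.w3.org/XML/1998/namespace': 'xml'},
    #    {'http://example.com/': 'ex'},
    #    None]
    #
    # a lookup for 'http://example.com/' walks from the innermost entry
    # outwards, skips the None placeholder, and returns 'ex'.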

    def end(self, name):
        self.soup.endData()
        completed_tag = self.soup.tagStack[-1]
        namespace, name = self._getNsTag(name)
        nsprefix = None
        if namespace is not None:
            for inverted_nsmap in reversed(self.nsmaps):
                if inverted_nsmap is not None and namespace in inverted_nsmap:
                    nsprefix = inverted_nsmap[namespace]
                    break
        self.soup.handle_endtag(name, nsprefix)
        if len(self.nsmaps) > 1:
            # This tag, or one of its parents, introduced a namespace
            # mapping, so pop it off the stack.
            self.nsmaps.pop()

    def pi(self, target, data):
        pass

    def data(self, content):
        self.soup.handle_data(content)

    def doctype(self, name, pubid, system):
        self.soup.endData()
        doctype = Doctype.for_name_and_ids(name, pubid, system)
        self.soup.object_was_parsed(doctype)

    def comment(self, content):
        "Handle comments as Comment objects."
        self.soup.endData()
        self.soup.handle_data(content)
        self.soup.endData(Comment)

    def test_fragment_to_document(self, fragment):
        """See `TreeBuilder`."""
        return u'<?xml version="1.0" encoding="utf-8"?>\n%s' % fragment


class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):

    features = [LXML, HTML, FAST, PERMISSIVE]
    is_xml = False

    def default_parser(self, encoding):
        return etree.HTMLParser

    def feed(self, markup):
        encoding = self.soup.original_encoding
        try:
            self.parser = self.parser_for(encoding)
            self.parser.feed(markup)
            self.parser.close()
        except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
            raise ParserRejectedMarkup(str(e))

    def test_fragment_to_document(self, fragment):
        """See `TreeBuilder`."""
        return u'<html><body>%s</body></html>' % fragment
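
# A rough usage sketch: BeautifulSoup selects one of these builders by
# feature name, so "xml" picks LXMLTreeBuilderForXML and "lxml" normally
# picks LXMLTreeBuilder (assuming the lxml library is installed):
#
#   from bs4 import BeautifulSoup
#   BeautifulSoup('<p>Some <b>bold</b> text', 'lxml').b.string   # u'bold'
#   soup = BeautifulSoup(
#       '<root xmlns:ex="http://example.com/"><ex:item/></root>', 'xml')
#   soup.root.contents[0].prefix                                 # 'ex'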