1from __future__ import absolute_import, division, unicode_literals
2from six import text_type
3
4from lxml import etree
5from ..treebuilders.etree import tag_regexp
6
7from . import _base
8
9from .. import ihatexml
10
11
def ensure_str(s):
    """Coerce *s* to text, letting ``None`` pass through untouched.

    A value that is already ``text_type`` is returned as-is.  Anything else
    is decoded through its ``decode`` method as strict UTF-8, so invalid
    byte sequences raise instead of being replaced.
    """
    if s is None or isinstance(s, text_type):
        return s
    return s.decode("utf-8", "strict")
19
20
class Root(object):
    """Synthetic document node placed above an element tree.

    ``children`` holds, in order: a ``Doctype`` (only when the tree's
    ``docinfo.internalDTD`` is truthy) followed by every top-level node,
    i.e. the root element together with all of its preceding and following
    siblings.
    """

    def __init__(self, et):
        """Collect the top-level nodes of the element tree *et*."""
        self.elementtree = et
        self.children = []
        # The document node itself never carries text.
        self.text = None
        self.tail = None

        docinfo = et.docinfo
        if docinfo.internalDTD:
            doctype = Doctype(self,
                              ensure_str(docinfo.root_name),
                              ensure_str(docinfo.public_id),
                              ensure_str(docinfo.system_url))
            self.children.append(doctype)

        # Rewind from the root element to its first preceding sibling ...
        current = et.getroot()
        earlier = current.getprevious()
        while earlier is not None:
            current = earlier
            earlier = current.getprevious()

        # ... then record every top-level node in document order.
        while current is not None:
            self.children.append(current)
            current = current.getnext()

    def __getitem__(self, key):
        """Index into the collected top-level children."""
        return self.children[key]

    def getnext(self):
        """Return ``None``: the document node has no sibling."""
        return None

    def __len__(self):
        """Always report a length of 1, whatever ``children`` holds."""
        return 1
50
51
class Doctype(object):
    """Stand-in node carrying a document's doctype information.

    Stores the doctype name plus its public and system identifiers and
    keeps a reference to the owning root node so that the following
    sibling can be located.
    """

    def __init__(self, root_node, name, public_id, system_id):
        """Record the owning *root_node* and the doctype identifiers."""
        self.root_node = root_node
        self.name, self.public_id, self.system_id = name, public_id, system_id

        # A doctype never carries text content.
        self.text = self.tail = None

    def getnext(self):
        """Return the entry at index 1 of the root node's ``children``.

        ``Root.__init__`` appends the doctype before any other node, so
        index 1 is the node that follows it.
        """
        siblings = self.root_node.children
        return siblings[1]
64
65
class FragmentRoot(Root):
    """Root node for a fragment given as a list of top-level items.

    ``Root.__init__`` is intentionally not called: there is no element
    tree here, only the supplied items, each wrapped in a
    ``FragmentWrapper`` that points back at this root.
    """

    def __init__(self, children):
        """Wrap every item of *children*, preserving their order."""
        wrapped = []
        for item in children:
            wrapped.append(FragmentWrapper(self, item))
        self.children = wrapped
        self.text = None
        self.tail = None

    def getnext(self):
        """Return ``None``: the fragment root has no sibling."""
        return None
73
74
class FragmentWrapper(object):
    """Proxy around one top-level item of a fragment.

    The wrapped object is kept in ``obj``.  ``text`` and ``tail`` are
    copied from it (converted with ``ensure_str``) when it exposes them and
    are ``None`` otherwise.  Attributes not defined on the wrapper are
    forwarded to ``obj``.  Sibling navigation goes through the owning
    fragment root's ``children`` list, and ``getparent`` always reports
    ``None``.
    """

    def __init__(self, fragment_root, obj):
        """Wrap *obj* as a child of *fragment_root*."""
        self.root_node = fragment_root
        self.obj = obj
        if hasattr(self.obj, 'text'):
            self.text = ensure_str(self.obj.text)
        else:
            self.text = None
        if hasattr(self.obj, 'tail'):
            self.tail = ensure_str(self.obj.tail)
        else:
            self.tail = None

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped object.

        Raises ``AttributeError`` when ``obj`` itself is missing.
        """
        # __getattr__ only runs when normal lookup fails.  If "obj" is what
        # is missing (instance made via __new__, e.g. while being copied or
        # unpickled), reading self.obj below would re-enter this method
        # forever; fail with the conventional AttributeError instead.
        if name == "obj":
            raise AttributeError(name)
        return getattr(self.obj, name)

    def getnext(self):
        """Return the wrapper following this one, or ``None`` if last."""
        siblings = self.root_node.children
        idx = siblings.index(self)
        if idx < len(siblings) - 1:
            return siblings[idx + 1]
        else:
            return None

    def __getitem__(self, key):
        """Index into the wrapped object."""
        return self.obj[key]

    def __bool__(self):
        """Truthiness of the wrapper is the truthiness of the wrapped object."""
        return bool(self.obj)

    # Python 2 consults __nonzero__ rather than __bool__.
    __nonzero__ = __bool__

    def getparent(self):
        """Return ``None``: fragment items are reported as parentless."""
        return None

    def __str__(self):
        return str(self.obj)

    def __unicode__(self):
        # Only used on Python 2, where str() would yield bytes; text_type
        # keeps the result a text string.
        return text_type(self.obj)

    def __len__(self):
        """Length of the wrapped object."""
        return len(self.obj)
116
117
class TreeWalker(_base.NonRecursiveTreeWalker):
    """Tree walker for lxml trees and for fragments given as lists.

    Text is not a node of its own in this tree model.  The walker
    addresses it as a two-tuple ``(node, key)`` where *key* is ``"text"``
    (text before the first child of *node*) or ``"tail"`` (text following
    *node*).
    """

    def __init__(self, tree):
        """Prepare *tree* for walking.

        An object exposing ``getroot`` is wrapped in ``Root``; a ``list``
        is wrapped in ``FragmentRoot``; anything else is handed to the
        base class unchanged.
        """
        if hasattr(tree, "getroot"):
            tree = Root(tree)
        elif isinstance(tree, list):
            tree = FragmentRoot(tree)
        _base.NonRecursiveTreeWalker.__init__(self, tree)
        self.filter = ihatexml.InfosetFilter()

    def getNodeDetails(self, node):
        """Return a tuple describing *node*, led by a ``_base`` type constant.

        The shapes produced are ``(TEXT, text)``, ``(DOCUMENT,)``,
        ``(DOCTYPE, name, public_id, system_id)``, ``(COMMENT, text)``,
        ``(ENTITY, name)`` and
        ``(ELEMENT, namespace, name, attrs, has_children)``, where *attrs*
        maps ``(namespace, local_name)`` pairs to values.
        """
        if isinstance(node, tuple):  # Text node
            node, key = node
            assert key in ("text", "tail"), "Text nodes are text or tail, found %s" % key
            return _base.TEXT, ensure_str(getattr(node, key))

        elif isinstance(node, Root):
            return (_base.DOCUMENT,)

        elif isinstance(node, Doctype):
            return _base.DOCTYPE, node.name, node.public_id, node.system_id

        elif isinstance(node, FragmentWrapper) and not hasattr(node, "tag"):
            # A wrapped fragment item without a tag is treated as text.
            # Convert it like every other text-producing branch so a byte
            # string item is not emitted undecoded.
            return _base.TEXT, ensure_str(node.obj)

        elif node.tag == etree.Comment:
            return _base.COMMENT, ensure_str(node.text)

        elif node.tag == etree.Entity:
            return _base.ENTITY, ensure_str(node.text)[1:-1]  # strip &;

        else:
            # This is assumed to be an ordinary element
            match = tag_regexp.match(ensure_str(node.tag))
            if match:
                namespace, tag = match.groups()
            else:
                namespace = None
                tag = ensure_str(node.tag)
            attrs = {}
            for name, value in node.attrib.items():
                name = ensure_str(name)
                value = ensure_str(value)
                match = tag_regexp.match(name)
                if match:
                    attrs[(match.group(1), match.group(2))] = value
                else:
                    attrs[(None, name)] = value
            # Child elements or leading text both count as "has children";
            # report it as a real bool rather than leaking the text itself.
            has_children = bool(len(node) > 0 or node.text)
            return (_base.ELEMENT, namespace, self.filter.fromXmlName(tag),
                    attrs, has_children)

    def getFirstChild(self, node):
        """Return the first child of *node*: its leading text if any,
        otherwise its first child node."""
        assert not isinstance(node, tuple), "Text nodes have no children"

        assert len(node) or node.text, "Node has no children"
        if node.text:
            return (node, "text")
        else:
            return node[0]

    def getNextSibling(self, node):
        """Return what follows *node* at the same level, or ``None``.

        For a non-text node the tail text comes before the next node.
        """
        if isinstance(node, tuple):  # Text node
            node, key = node
            assert key in ("text", "tail"), "Text nodes are text or tail, found %s" % key
            if key == "text":
                # XXX: we cannot use a "bool(node) and node[0] or None" construct here
                # because node[0] might evaluate to False if it has no child element
                if len(node):
                    return node[0]
                else:
                    return None
            else:  # tail
                return node.getnext()

        return (node, "tail") if node.tail else node.getnext()

    def getParentNode(self, node):
        """Return the parent of *node*.

        Leading text belongs to the node carrying it; tail text shares the
        parent of the node it trails.
        """
        if isinstance(node, tuple):  # Text node
            node, key = node
            assert key in ("text", "tail"), "Text nodes are text or tail, found %s" % key
            if key == "text":
                return node
            # else: fallback to "normal" processing

        return node.getparent()
202