1from __future__ import absolute_import, division, unicode_literals
2
3import os
4import sys
5import codecs
6import glob
7import xml.sax.handler
8
# Directory containing this file; only needed while the paths below are set up.
base_path = os.path.split(__file__)[0]

# Directory that get_data_files() searches for test data files.
test_dir = os.path.join(base_path, 'testdata')
# Put the directory two levels above this file at the front of sys.path.
# This must happen before the html5lib import below so that import is resolved
# against that directory first.
sys.path.insert(0, os.path.abspath(os.path.join(base_path,
                                                os.path.pardir,
                                                os.path.pardir)))

from html5lib import treebuilders
# base_path was only a helper for the setup above; drop it from the namespace.
del base_path
18
# Build a dict of available trees, mapping a tree type name to the tree
# builder returned by treebuilders.getTreeBuilder
treeTypes = {"DOM": treebuilders.getTreeBuilder("dom")}

# Try whatever etree implementations are available from a list that are
# "supposed" to work
try:
    import xml.etree.ElementTree as ElementTree
    treeTypes['ElementTree'] = treebuilders.getTreeBuilder("etree", ElementTree, fullTree=True)
except ImportError:
    # Fall back to the standalone elementtree package; if that is missing
    # too, no 'ElementTree' entry is registered.
    try:
        import elementtree.ElementTree as ElementTree
        treeTypes['ElementTree'] = treebuilders.getTreeBuilder("etree", ElementTree, fullTree=True)
    except ImportError:
        pass

try:
    import xml.etree.cElementTree as cElementTree
    treeTypes['cElementTree'] = treebuilders.getTreeBuilder("etree", cElementTree, fullTree=True)
except ImportError:
    # Fall back to the standalone cElementTree module; if that is missing
    # too, no 'cElementTree' entry is registered.
    try:
        import cElementTree
        treeTypes['cElementTree'] = treebuilders.getTreeBuilder("etree", cElementTree, fullTree=True)
    except ImportError:
        pass

# lxml is optional: the 'lxml' entry is only registered when the import works.
try:
    import lxml.etree as lxml  # flake8: noqa
except ImportError:
    pass
else:
    treeTypes['lxml'] = treebuilders.getTreeBuilder("lxml")
50
51
def get_data_files(subdirectory, files='*.dat'):
    """Return the paths in ``test_dir/subdirectory`` matching the glob *files*.

    The result is whatever ``glob.glob`` returns for the joined pattern, so
    it is a (possibly empty) list in no particular order.
    """
    pattern = os.path.join(test_dir, subdirectory, files)
    return glob.glob(pattern)
54
55
class DefaultDict(dict):
    """A dict whose ``d[key]`` lookup yields a fixed fallback for absent keys.

    The fallback is the first constructor argument; the remaining arguments
    are passed to ``dict`` unchanged. Looking up a missing key returns the
    fallback without storing it, so membership tests and truthiness are not
    affected by lookups.
    """

    def __init__(self, default, *args, **kwargs):
        # Value returned by __getitem__ for keys that are not present.
        self.default = default
        super(DefaultDict, self).__init__(*args, **kwargs)

    def __getitem__(self, key):
        fallback = self.default
        return super(DefaultDict, self).get(key, fallback)
63
64
class TestData(object):
    """Iterate over the tests stored in a sectioned test-data file.

    A line starting with ``#`` opens a section named by the rest of that line
    (stripped); the lines that follow, up to the next heading, are the
    section's content. A section whose name equals ``newTestHeading`` starts
    a new test. Iterating yields one ``DefaultDict`` per test that maps
    section names to section text; sections a test does not have read as
    ``None``.
    """

    def __init__(self, filename, newTestHeading="data", encoding="utf8"):
        """Open *filename* for reading.

        filename -- path of the test-data file.
        newTestHeading -- section name that marks the start of a new test.
            It is compared with headings parsed from the file, so on
            Python 3 it has to be ``bytes`` when *encoding* is None.
        encoding -- text encoding used to decode the file, or None to read
            the file in binary mode and produce byte strings.

        Raises whatever ``open``/``codecs.open`` raise if the file cannot be
        opened.
        """
        if encoding is None:
            self.f = open(filename, mode="rb")
        else:
            self.f = codecs.open(filename, encoding=encoding)
        self.encoding = encoding
        self.newTestHeading = newTestHeading

    def __del__(self):
        # ``f`` does not exist when opening the file failed in __init__ (or
        # when __init__ never ran); closing unconditionally would raise
        # AttributeError during finalisation.
        f = getattr(self, "f", None)
        if f is not None:
            f.close()

    def __iter__(self):
        """Yield one ``DefaultDict`` of ``{heading: text}`` per test.

        Lines before the first heading are ignored. The underlying file is
        read as iteration proceeds and is not rewound.
        """
        data = DefaultDict(None)
        key = None
        for line in self.f:
            heading = self.isSectionHeading(line)
            if heading:
                if data and heading == self.newTestHeading:
                    # Remove trailing newline
                    # (this is in addition to the newline that
                    # normaliseOutput strips from every section).
                    data[key] = data[key][:-1]
                    yield self.normaliseOutput(data)
                    data = DefaultDict(None)
                key = heading
                # Start the section empty, in the type the file produces.
                data[key] = "" if self.encoding else b""
            elif key is not None:
                data[key] += line
        # Emit the final test, which has no following heading to trigger it.
        if data:
            yield self.normaliseOutput(data)

    def isSectionHeading(self, line):
        """If the current line is a test section heading return the heading,
        otherwise return False"""
        if line.startswith("#" if self.encoding else b"#"):
            return line[1:].strip()
        else:
            return False

    def normaliseOutput(self, data):
        """Strip one trailing newline from every value of *data*, in place,
        and return *data*."""
        newline = "\n" if self.encoding else b"\n"
        # Iterate over a snapshot since values are reassigned in the loop.
        for key, value in list(data.items()):
            if value.endswith(newline):
                data[key] = value[:-1]
        return data
110
111
def convert(stripChars):
    """Return a function that reformats serialised tree output.

    The returned function takes a string and removes the first *stripChars*
    characters from every line that begins with "|"; all other lines are
    kept unchanged.
    """
    def convertData(data):
        """convert the output of str(document) to the format used in the testcases"""
        return "\n".join(
            row[stripChars:] if row.startswith("|") else row
            for row in data.split("\n")
        )
    return convertData
124
# Converter that drops the first two characters of every line starting
# with "|" and leaves all other lines unchanged.
convertExpected = convert(2)
126
127
def errorMessage(input, expected, actual):
    """Build the failure message for a test whose output did not match.

    input -- the test input.
    expected -- the output the test expected.
    actual -- the output that was actually produced.

    Each value is shown via ``repr``. Returns the message as text; on
    Python 2 it is encoded to an ASCII byte string with non-ASCII characters
    backslash-escaped.
    """
    msg = ("Input:\n%s\nExpected:\n%s\nReceived:\n%s\n" %
           (repr(input), repr(expected), repr(actual)))
    # Index access works even where version_info has no named fields.
    if sys.version_info[0] == 2:
        msg = msg.encode("ascii", "backslashreplace")
    return msg
134
135
class TracingSaxHandler(xml.sax.handler.ContentHandler):
    """SAX content handler that records the events it receives.

    Every callback except the prefix-mapping ones appends an entry to
    ``self.visited``: the bare event name for document start/end, otherwise a
    tuple of the event name followed by the callback's arguments. Attribute
    collections are stored as plain ``dict`` snapshots.
    """

    def __init__(self):
        xml.sax.handler.ContentHandler.__init__(self)
        # Recorded events, in the order they were received.
        self.visited = []

    def startDocument(self):
        self.visited.append('startDocument')

    def endDocument(self):
        self.visited.append('endDocument')

    def startPrefixMapping(self, prefix, uri):
        # These are ignored as their order is not guaranteed
        pass

    def endPrefixMapping(self, prefix):
        # These are ignored as their order is not guaranteed
        pass

    def startElement(self, name, attrs):
        # Snapshot the attributes: the parser may reuse the attrs object
        # after this call returns, and a dict matches what startElementNS
        # records.
        self.visited.append(('startElement', name, dict(attrs)))

    def endElement(self, name):
        self.visited.append(('endElement', name))

    def startElementNS(self, name, qname, attrs):
        self.visited.append(('startElementNS', name, qname, dict(attrs)))

    def endElementNS(self, name, qname):
        self.visited.append(('endElementNS', name, qname))

    def characters(self, content):
        self.visited.append(('characters', content))

    def ignorableWhitespace(self, whitespace):
        self.visited.append(('ignorableWhitespace', whitespace))

    def processingInstruction(self, target, data):
        self.visited.append(('processingInstruction', target, data))

    def skippedEntity(self, name):
        self.visited.append(('skippedEntity', name))
178