1#!/usr/bin/python
2#
3# Copyright (C) 2009 Chia-I Wu <olv@0xlab.org>
4#
5# Permission is hereby granted, free of charge, to any person obtaining a
6# copy of this software and associated documentation files (the "Software"),
7# to deal in the Software without restriction, including without limitation
8# on the rights to use, copy, modify, merge, publish, distribute, sub
9# license, and/or sell copies of the Software, and to permit persons to whom
10# the Software is furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice (including the next
13# paragraph) shall be included in all copies or substantial portions of the
14# Software.
15#
16# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.  IN NO EVENT SHALL
19# IBM AND/OR ITS SUPPLIERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22# IN THE SOFTWARE.
23"""
24A parser for APIspec.
25"""
26
class SpecError(Exception):
    """Raised when the APIspec document is malformed or inconsistent."""
    pass
29
30
class Spec(object):
    """A Spec is an abstraction of the API spec.

    It indexes the <template> and <api> children of the root element by
    their "name" attribute and locates the API marked
    implementation="true".
    """

    def __init__(self, doc):
        """Index the children of doc's root element.

        Raises SpecError on a child element other than <template>/<api>, or
        when no <api> is marked implementation="true".
        """
        self.doc = doc

        self.spec_node = doc.getRootElement()
        self.tmpl_nodes = {}
        self.api_nodes = {}
        self.impl_node = None

        # parse <apispec>
        node = self.spec_node.children
        while node:
            if node.type == "element":
                if node.name == "template":
                    self.tmpl_nodes[node.prop("name")] = node
                elif node.name == "api":
                    self.api_nodes[node.prop("name")] = node
                    # pick the first implementation in document order so the
                    # choice does not depend on dict iteration order
                    if (self.impl_node is None and
                            node.prop("implementation") == "true"):
                        self.impl_node = node
                else:
                    raise SpecError("unexpected node %s in apispec" %
                            node.name)
            node = node.next

        if self.impl_node is None:
            raise SpecError("unable to find an implementation")

    def get_impl(self):
        """Return the implementation."""
        return API(self, self.impl_node)

    def get_api(self, name):
        """Return the named API (KeyError if there is no such <api>)."""
        return API(self, self.api_nodes[name])
70
71
class API(object):
    """An API consists of categories and functions."""

    def __init__(self, spec, api_node):
        """Realize every <function> of api_node using the templates in spec.

        Raises SpecError on an unexpected child node, or when a function
        cannot be realized (including when it names an unknown template).
        """
        self.name = api_node.prop("name")
        self.is_impl = (api_node.prop("implementation") == "true")

        self.categories = []
        self.functions = []

        # parse <api>
        func_nodes = []
        node = api_node.children
        while node:
            if node.type == "element":
                if node.name == "category":
                    self.categories.append(node.prop("name"))
                elif node.name == "function":
                    func_nodes.append(node)
                else:
                    raise SpecError("unexpected node %s in api" % node.name)
            node = node.next

        # realize functions
        for func_node in func_nodes:
            try:
                tmpl_name = func_node.prop("template")
                tmpl_node = spec.tmpl_nodes.get(tmpl_name)
                if tmpl_node is None:
                    raise SpecError("unknown template %s" % tmpl_name)
                # is_impl is passed as force_skip_desc: descriptions are not
                # parsed for the implementation API
                func = Function(tmpl_node, func_node, self.is_impl,
                                self.categories)
            except SpecError as e:
                func_name = func_node.prop("name")
                raise SpecError("failed to parse %s: %s" % (func_name, e))
            self.functions.append(func)

    def match(self, func, conversions=None):
        """Find a matching function in the API.

        Returns (match, need_conv).  An exact match (no conversion) wins
        immediately; otherwise the last function that matches with a
        conversion is returned.  (None, False) means nothing matched.
        conversions maps a source base type to the list of base types it
        can be converted to.
        """
        if conversions is None:
            conversions = {}
        match = None
        need_conv = False
        for f in self.functions:
            matched, conv = f.match(func, conversions)
            if matched:
                match = f
                need_conv = conv
                # exact match
                if not need_conv:
                    break
        return (match, need_conv)
120
121
class Function(object):
    """Parse and realize a <template> node.

    The template is specialized by the attributes of a <function> node:
    gltype replaces "GLtype" in parameter types, vector_size and
    expand_vector control how <vector> parameters are realized, and nodes
    with a category outside the supported categories are ignored.
    """

    def __init__(self, tmpl_node, func_node, force_skip_desc=False,
                 categories=None):
        """Realize tmpl_node according to func_node.

        force_skip_desc -- skip <desc> parsing regardless of skip_desc
        categories -- supported categories (None means none)
        Raises SpecError if the template has no <proto> or an unexpected
        node.
        """
        self.tmpl_name = tmpl_node.prop("name")
        self.direction = tmpl_node.prop("direction")

        self.name = func_node.prop("name")
        self.prefix = func_node.prop("default_prefix")
        self.is_external = (func_node.prop("external") == "true")

        self._skip_desc = (bool(force_skip_desc) or
                           func_node.prop("skip_desc") == "true")

        # None instead of a shared mutable [] default
        self._categories = categories if categories is not None else []

        # these attributes decide how the template is realized
        self._gltype = func_node.prop("gltype")
        if func_node.hasProp("vector_size"):
            self._vector_size = int(func_node.prop("vector_size"))
        else:
            self._vector_size = 0
        self._expand_vector = (func_node.prop("expand_vector") == "true")

        self.return_type = "void"
        param_nodes = []

        # find <proto>
        proto_node = tmpl_node.children
        while proto_node:
            if proto_node.type == "element" and proto_node.name == "proto":
                break
            proto_node = proto_node.next
        if not proto_node:
            raise SpecError("no proto")
        # and parse it
        node = proto_node.children
        while node:
            if node.type == "element":
                if node.name == "return":
                    self.return_type = node.prop("type")
                elif node.name in ("param", "vector"):
                    if self.support_node(node):
                        # make sure the node is not hidden
                        if not (self._expand_vector and
                                (node.prop("hide_if_expanded") == "true")):
                            param_nodes.append(node)
                else:
                    raise SpecError("unexpected node %s in proto" % node.name)
            node = node.next

        self._init_params(param_nodes)
        self._init_descs(tmpl_node, param_nodes)

    def __str__(self):
        return "%s %s%s(%s)" % (self.return_type, self.prefix, self.name,
                self.param_string(True))

    def _init_params(self, param_nodes):
        """Parse and initialize parameters.

        A <vector> stays a single parameter unless vectors are expanded.
        When they are, it is replaced by its first `size` supported <param>
        children.
        """
        self.params = []

        for param_node in param_nodes:
            size = self.param_node_size(param_node)
            # when no expansion, vector is just like param
            if param_node.name == "param" or not self._expand_vector:
                self.params.append(Parameter(param_node, self._gltype, size))
                continue

            if not size or size > param_node.lsCountNode():
                # name the vector node itself; no Parameter exists for it yet
                raise SpecError("could not expand %s with unknown or "
                                "mismatch sizes" % param_node.prop("name"))

            # expand the vector
            expanded_params = []
            child = param_node.children
            while child:
                if (child.type == "element" and child.name == "param" and
                        self.support_node(child)):
                    expanded_params.append(Parameter(child, self._gltype))
                    if len(expanded_params) == size:
                        break
                child = child.next
            # just in case that lsCountNode counts unknown nodes
            if len(expanded_params) < size:
                raise SpecError("not enough named parameters")

            self.params.extend(expanded_params)

    def _init_descs(self, tmpl_node, param_nodes):
        """Parse and initialize parameter descriptions."""
        self.checker = Checker()
        if self._skip_desc:
            return

        node = tmpl_node.children
        while node:
            if node.type == "element" and node.name == "desc":
                if self.support_node(node):
                    # parse <desc>
                    desc = Description(node, self._categories)
                    self.checker.add_desc(desc)
            node = node.next

        self.checker.validate(self, param_nodes)

    def support_node(self, node):
        """Return true if a node is in the supported category."""
        return (not node.hasProp("category") or
                node.prop("category") in self._categories)

    def get_param(self, name):
        """Return the named parameter, or None."""
        return next((param for param in self.params if param.name == name),
                    None)

    def param_node_size(self, param):
        """Return the size of a vector node (0 for non-vectors or unknown).

        A numeric "size" attribute wins.  Otherwise the function's
        vector_size is used.  When that is also unset and vectors are
        expanded, the node's child count is used.
        """
        if param.name != "vector":
            return 0

        # a missing size attribute is treated like a non-numeric one
        size = param.prop("size") or ""
        size = int(size) if size.isdigit() else 0
        if not size:
            size = self._vector_size
            if not size and self._expand_vector:
                # return the number of named parameters
                size = param.lsCountNode()
        return size

    def param_string(self, declaration):
        """Return the C code of the parameters.

        With declaration true, return "type name" pairs ("void" when there
        are none); otherwise return the bare names.
        """
        if not declaration:
            return ", ".join(param.name for param in self.params)
        args = []
        for param in self.params:
            # no space between a pointer type and the name
            sep = "" if param.type.endswith("*") else " "
            args.append("%s%s%s" % (param.type, sep, param.name))
        return ", ".join(args) or "void"

    def match(self, other, conversions=None):
        """Return (matched, need_conv) for other against this function.

        Template, return type, parameter count and each parameter's vector
        flag and size must agree.  Parameter types may differ only when
        conversions maps other's base type to a list containing this
        function's base type; need_conv is then True.
        """
        if conversions is None:
            conversions = {}
        if (self.tmpl_name != other.tmpl_name or
                self.return_type != other.return_type or
                len(self.params) != len(other.params)):
            return (False, False)

        need_conv = False
        for src, dst in zip(other.params, self.params):
            if src.is_vector != dst.is_vector or src.size != dst.size:
                return (False, False)
            if src.type != dst.type:
                if dst.base_type() in conversions.get(src.base_type(), []):
                    need_conv = True
                else:
                    # unable to convert
                    return (False, False)

        return (True, need_conv)
295
296
class Parameter(object):
    """A parameter of a function."""

    def __init__(self, param_node, gltype=None, size=0):
        """Create a parameter from a <param> or <vector> node.

        gltype -- substituted for the "GLtype" placeholder in the type
        size -- vector size as computed by the caller (0 if none/unknown)
        Raises SpecError if the type contains "GLtype" and no gltype is
        given.
        """
        self.is_vector = (param_node.name == "vector")

        self.name = param_node.prop("name")
        self.size = size

        # avoid shadowing the builtin `type`
        param_type = param_node.prop("type")
        if gltype:
            param_type = param_type.replace("GLtype", gltype)
        elif "GLtype" in param_type:
            raise SpecError("parameter %s has unresolved type" % self.name)

        self.type = param_type

    def base_type(self):
        """Return the base GL type by stripping qualifiers.

        Qualifiers and pointer stars are dropped: "const GLfloat *" and
        "GLfloat*" both give "GLfloat".  Returns None when no GL type is
        named (e.g. "void *"), so callers comparing base types simply see a
        mismatch instead of an IndexError.
        """
        for token in self.type.replace("*", " ").split():
            if token.startswith("GL"):
                return token
        return None
317
318
class Checker(object):
    """A checker is the collection of all descriptions on the same level.

    Descriptions are grouped into "switches" keyed by parameter name;
    descriptions of the same parameter are concatenated into one switch.
    """

    def __init__(self):
        # parameter name -> list of descriptions
        self.switches = {}
        # parameter name -> {attr: value} for the attributes that must agree
        self.switch_constants = {}

    def add_desc(self, desc):
        """Add a description.

        Raises SpecError if desc disagrees with an earlier description of
        the same parameter on index, error, convert or size_str.
        """
        # TODO allow index to vary
        const_attrs = ("index", "error", "convert", "size_str")
        if desc.name not in self.switches:
            self.switches[desc.name] = []
            self.switch_constants[desc.name] = dict.fromkeys(const_attrs)

        # some attributes, like error code, should be the same for all descs
        consts = self.switch_constants[desc.name]
        for attr in const_attrs:
            value = getattr(desc, attr)
            if value is None:
                continue
            if consts[attr] is not None and consts[attr] != value:
                raise SpecError("mismatch %s for %s" % (attr, desc.name))
            consts[attr] = value

        self.switches[desc.name].append(desc)

    def validate(self, func, param_nodes):
        """Validate the checker against a function.

        Keeps only descriptions that validate against func and are not
        no-ops.  Returns False, leaving the checker unchanged, if some
        parameter has no valid description at all.
        """
        tmp = Checker()

        for switch in self.switches.values():
            valid_descs = [desc for desc in switch
                           if desc.validate(func, param_nodes)]
            # no possible values
            if not valid_descs:
                return False
            for desc in valid_descs:
                if not desc._is_noop:
                    tmp.add_desc(desc)

        self.switches = tmp.switches
        self.switch_constants = tmp.switch_constants
        return True

    def flatten(self, name=None):
        """Return a flat list of all descriptions of the named parameter.

        Dependent descriptions are included recursively; with no name, all
        descriptions are returned.
        """
        flat_list = []
        for switch in self.switches.values():
            for desc in switch:
                if not name or desc.name == name:
                    flat_list.append(desc)
                flat_list.extend(desc.checker.flatten(name))
        return flat_list

    def always_check(self, name):
        """Return true if the parameter is checked in all possible paths."""
        if name in self.switches:
            return True

        # a param is always checked if any of the switch always checks it;
        # a switch always checks it if all of the descs always check it
        return any(all(desc.checker.always_check(name) for desc in switch)
                   for switch in self.switches.values())

    def _c_switch(self, name, indent="\t"):
        """Output C switch-statement for the named parameter, for debug.

        Returns a list of lines; empty when no description of the parameter
        has values.
        """
        switch = self.switches.get(name, [])
        # make sure there are valid values
        if not any(desc.values for desc in switch):
            return []

        stmts = []
        var = switch[0].name
        if switch[0].index >= 0:
            var += "[%d]" % switch[0].index
        stmts.append("switch (%s) { /* assume GLenum */" % var)

        for desc in switch:
            if desc.values:
                for val in desc.values:
                    stmts.append("case %s:" % val)
                for dep_name in desc.checker.switches.keys():
                    dep_stmts = [indent + s for s in
                                 desc.checker._c_switch(dep_name, indent)]
                    stmts.extend(dep_stmts)
                stmts.append(indent + "break;")

        stmts.append("default:")
        stmts.append(indent + "ON_ERROR(%s);" % switch[0].error)
        stmts.append(indent + "break;")
        stmts.append("}")

        return stmts

    def dump(self, indent="\t"):
        """Dump the descriptions in C code to stdout, one switch per name."""
        for name in self.switches.keys():
            # honor the caller's indent (it was previously ignored)
            c_switch = self._c_switch(name, indent)
            print("\n".join(c_switch))
435
436
class Description(object):
    """A description describes a parameter and its relationship with other
    parameters.

    It is built from a <desc> node: <value> and <range> children give the
    valid values of the parameter, and nested <desc> children become
    dependent descriptions held in self.checker.
    """

    def __init__(self, desc_node, categories=None):
        """Parse desc_node, ignoring children of unsupported categories.

        Raises SpecError on an unexpected child node.
        """
        # None instead of a shared mutable [] default
        self._categories = categories if categories is not None else []
        # set by validate() when the described param is not in the function
        self._is_noop = False

        self.name = desc_node.prop("name")
        # index into a vector parameter, or -1; updated by validate()
        self.index = -1

        self.error = desc_node.prop("error") or "GL_INVALID_ENUM"
        # vector_size may be C code
        self.size_str = desc_node.prop("vector_size")

        # true once any value (or range base) starts with "GL_"
        self._has_enum = False
        self.values = []
        dep_nodes = []

        # parse <desc>
        valid_names = ("value", "range", "desc")
        node = desc_node.children
        while node:
            if node.type == "element":
                if node.name not in valid_names:
                    raise SpecError("unexpected node %s in desc" % node.name)
                category = node.prop("category")
                # ignore nodes that require unsupported categories
                if not category or category in self._categories:
                    if node.name == "value":
                        val = node.prop("name")
                        if val.startswith("GL_"):
                            self._has_enum = True
                        self.values.append(val)
                    elif node.name == "range":
                        first = int(node.prop("from"))
                        last = int(node.prop("to"))
                        base = node.prop("base") or ""
                        if base.startswith("GL_"):
                            self._has_enum = True
                        # expand range, both ends inclusive
                        self.values.extend("%s%d" % (base, i)
                                           for i in range(first, last + 1))
                    else:  # dependent desc
                        dep_nodes.append(node)
            node = node.next

        # default to convert if there is no enum
        self.convert = not self._has_enum
        if desc_node.hasProp("convert"):
            self.convert = (desc_node.prop("convert") == "true")

        self._init_deps(dep_nodes)

    def _init_deps(self, dep_nodes):
        """Parse and initialize dependents."""
        self.checker = Checker()

        for dep_node in dep_nodes:
            # recursion!
            dep = Description(dep_node, self._categories)
            self.checker.add_desc(dep)

    def _search_param_node(self, param_nodes, name=None):
        """Search the template parameters for the named node.

        Returns (node, index).  index is the position of the named <param>
        inside a <vector> node, or -1 when node itself has the name.
        Returns (None, -1) when not found.
        """
        param_node = None
        param_index = -1

        if not name:
            name = self.name
        for node in param_nodes:
            if name == node.prop("name"):
                param_node = node
            elif node.name == "vector":
                child = node.children
                idx = 0
                while child:
                    if child.type == "element" and child.name == "param":
                        if name == child.prop("name"):
                            param_node = node
                            param_index = idx
                            break
                        idx += 1
                    child = child.next
            if param_node:
                break
        return (param_node, param_index)

    def _find_final(self, func, param_nodes):
        """Find the final parameter.

        Returns (valid, param, param_index); param is None when the
        description is valid but its parameter is absent from func.
        Raises SpecError if the name is not a template parameter at all.
        """
        param = func.get_param(self.name)
        param_index = -1

        # the described param is not in the final function
        if not param:
            # search the template parameters
            node, index = self._search_param_node(param_nodes)
            if not node:
                raise SpecError("invalid desc %s in %s" %
                        (self.name, func.name))

            # a named parameter of a vector
            if index >= 0:
                param = func.get_param(node.prop("name"))
                param_index = index
            elif node.name == "vector":
                # must be an expanded vector, check its size
                if self.size_str and self.size_str.isdigit():
                    size = int(self.size_str)
                    expanded_size = func.param_node_size(node)
                    if size != expanded_size:
                        return (False, None, -1)
            # otherwise, it is a valid, but no-op, description

        return (True, param, param_index)

    def validate(self, func, param_nodes):
        """Validate a description against certain function.

        On success the description is updated in place (name, index, and
        possibly size_str or _is_noop).
        """
        if self.checker.switches and not self.values:
            raise SpecError("no valid values for %s" % self.name)

        valid, param, param_index = self._find_final(func, param_nodes)
        if not valid:
            return False

        # the description is valid, but the param is gone
        # mark it no-op so that it will be skipped
        if not param:
            self._is_noop = True
            return True

        if param.is_vector:
            # if param was known, this should have been done in __init__
            if self._has_enum:
                self.size_str = "1"
            # size mismatch
            if (param.size and self.size_str and self.size_str.isdigit() and
                    param.size != int(self.size_str)):
                return False
        elif self.size_str:
            # only vector accepts vector_size
            raise SpecError("vector_size is invalid for %s" % param.name)

        if not self.checker.validate(func, param_nodes):
            return False

        # update the description
        self.name = param.name
        self.index = param_index

        return True
593
594
def main():
    """Parse APIspec.xml and realize its implementation and some APIs.

    This is a self-test of the parser: SpecError propagates if the spec is
    inconsistent, and a message is printed on success.
    """
    import libxml2

    filename = "APIspec.xml"
    apinames = ["GLES1.1", "GLES2.0"]

    # parser options are bit flags; combine them with |
    doc = libxml2.readFile(filename, None,
            libxml2.XML_PARSE_DTDLOAD |
            libxml2.XML_PARSE_DTDVALID |
            libxml2.XML_PARSE_NOBLANKS)

    try:
        spec = Spec(doc)
        # realizing an API parses all of its functions
        spec.get_impl()
        for apiname in apinames:
            spec.get_api(apiname)
    finally:
        # free the document even when the spec fails to parse
        doc.freeDoc()

    print("%s is successfully parsed" % filename)
614
615
if __name__ == "__main__":
    # executed as a script rather than imported: run the parser self-test
    main()
618