1# cython: infer_types=True
2
3#
4#   Tree visitor and transform framework
5#
6import inspect
7
8from Cython.Compiler import TypeSlots
9from Cython.Compiler import Builtin
10from Cython.Compiler import Nodes
11from Cython.Compiler import ExprNodes
12from Cython.Compiler import Errors
13from Cython.Compiler import DebugFlags
14
15import cython
16
17
class TreeVisitor(object):
    """
    Base class for writing visitors for a Cython tree, contains utilities for
    recursing such trees using visitors. Each node is
    expected to have a child_attrs iterable containing the names of attributes
    containing child nodes or lists of child nodes. Lists are not considered
    part of the tree structure (i.e. contained nodes are considered direct
    children of the parent node).

    visitchildren visits each of the children of a given node (see the
    visitchildren documentation). When recursing the tree using visitchildren,
    an attribute access_path is maintained which gives information about the
    current location in the tree as a stack of tuples:
    (parent_node, attrname, index), representing the node, attribute and
    optional list index that was taken in each step in the path to the
    current node.

    Handlers are looked up by class name: an object of class C is passed to
    the first visit_<Name> method found while walking C's MRO (most specific
    class first).  The result of the lookup is cached per type.

    Example:

    >>> class SampleNode(object):
    ...     child_attrs = ["head", "body"]
    ...     def __init__(self, value, head=None, body=None):
    ...         self.value = value
    ...         self.head = head
    ...         self.body = body
    ...     def __repr__(self): return "SampleNode(%s)" % self.value
    ...
    >>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)])
    >>> class MyVisitor(TreeVisitor):
    ...     def visit_SampleNode(self, node):
    ...         print("in %s %s" % (node.value, self.access_path))
    ...         self.visitchildren(node)
    ...         print("out %s" % node.value)
    ...
    >>> MyVisitor().visit(tree)
    in 0 []
    in 1 [(SampleNode(0), 'head', None)]
    out 1
    in 2 [(SampleNode(0), 'body', 0)]
    out 2
    in 3 [(SampleNode(0), 'body', 1)]
    out 3
    out 0
    """
    def __init__(self):
        super(TreeVisitor, self).__init__()
        self.dispatch_table = {}
        self.access_path = []

    def dump_node(self, node, indent=0):
        """
        Return a compact description of a single node for crash reports.

        The result contains the class name, the source position
        (file:line:col, when the node has one) and every public attribute
        holding a string or number.  List attributes are summarised by their
        length.  Child attributes, attributes that are None or compare equal
        to 0, and a few noisy attributes are skipped.  *indent* is accepted
        but currently unused.
        """
        ignored = list(node.child_attrs or []) + [u'child_attrs', u'pos',
                                                  u'gil_message', u'cpp_message',
                                                  u'subexprs']
        values = []
        pos = getattr(node, 'pos', None)
        if pos:
            source = pos[0]
            if source:
                import os.path
                source = os.path.basename(source.get_description())
            values.append(u'%s:%s:%s' % (source, pos[1], pos[2]))
        attribute_names = dir(node)
        attribute_names.sort()
        for attr in attribute_names:
            if attr in ignored:
                continue
            if attr.startswith(u'_') or attr.endswith(u'_'):
                continue
            try:
                value = getattr(node, attr)
            except AttributeError:
                continue
            if value is None or value == 0:
                continue
            elif isinstance(value, list):
                value = u'[...]/%d' % len(value)
            elif not isinstance(value, (str, unicode, long, int, float)):
                continue
            else:
                value = repr(value)
            values.append(u'%s = %s' % (attr, value))
        return u'%s(%s)' % (node.__class__.__name__,
                           u',\n    '.join(values))

    def _find_node_path(self, stacktrace):
        """
        Collect the traceback frames whose ``self`` is a Nodes.Node.

        Returns a tuple (last_traceback, nodes):
        - nodes is a list of (node, method_name, (filename, lineno)) entries,
          ordered from the outermost to the innermost frame;
        - last_traceback is the traceback entry of the innermost such frame,
          or the traceback passed in if no node frame was found.
        """
        import os.path
        last_traceback = stacktrace
        nodes = []
        while hasattr(stacktrace, 'tb_frame'):
            frame = stacktrace.tb_frame
            node = frame.f_locals.get(u'self')
            if isinstance(node, Nodes.Node):
                code = frame.f_code
                method_name = code.co_name
                pos = (os.path.basename(code.co_filename),
                       frame.f_lineno)
                nodes.append((node, method_name, pos))
                last_traceback = stacktrace
            stacktrace = stacktrace.tb_next
        return (last_traceback, nodes)

    def _raise_compiler_error(self, child, e):
        """
        Re-raise the unexpected exception *e* as an Errors.CompilerCrash.

        The crash report contains a dump of every node on the current
        access_path and of every node method found in the active traceback.
        The reported position is that of the innermost such node, falling
        back to *child*.  Must be called from inside an ``except`` block,
        because it reads sys.exc_info().
        """
        import sys
        trace = ['']
        for parent, attribute, index in self.access_path:
            node = getattr(parent, attribute)
            if index is None:
                index = ''
            else:
                node = node[index]
                index = u'[%d]' % index
            trace.append(u'%s.%s%s = %s' % (
                parent.__class__.__name__, attribute, index,
                self.dump_node(node)))
        stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
        last_node = child
        for node, method_name, pos in called_nodes:
            last_node = node
            trace.append(u"File '%s', line %d, in %s: %s" % (
                pos[0], pos[1], method_name, self.dump_node(node)))
        raise Errors.CompilerCrash(
            getattr(last_node, 'pos', None), self.__class__.__name__,
            u'\n'.join(trace), e, stacktrace)

    @cython.final
    def find_handler(self, obj):
        """
        Return the bound visit_<ClassName> method for *obj*.

        The MRO of type(obj) is searched, most specific class first.
        Raises RuntimeError if no class in the hierarchy has a handler.
        """
        cls = type(obj)
        pattern = "visit_%s"
        for mro_cls in inspect.getmro(cls):
            handler_method = getattr(self, pattern % mro_cls.__name__, None)
            if handler_method is not None:
                return handler_method
        # Report the failure through the exception rather than printing
        # debug output to stdout.  _visit() wraps this error in a
        # CompilerCrash that also dumps the full access path.
        message = "Visitor %r does not accept object: %s" % (self, obj)
        if self.access_path:
            parent = self.access_path[-1][0]
            message += " (child of %s at %s)" % (
                type(parent).__name__, getattr(parent, 'pos', None))
        raise RuntimeError(message)

    def visit(self, obj):
        return self._visit(obj)

    @cython.final
    def _visit(self, obj):
        """
        Dispatch *obj* to its handler and return the handler's result.

        Compile and abort errors propagate unchanged.  Any other exception is
        converted into a CompilerCrash, unless
        DebugFlags.debug_no_exception_intercept is set.
        """
        try:
            try:
                handler_method = self.dispatch_table[type(obj)]
            except KeyError:
                handler_method = self.find_handler(obj)
                self.dispatch_table[type(obj)] = handler_method
            return handler_method(obj)
        except Errors.CompileError:
            raise
        except Errors.AbortError:
            raise
        except Exception:
            if DebugFlags.debug_no_exception_intercept:
                raise
            # sys.exc_info() works on every Python version; the
            # "except Exception, e" spelling is Python-2-only.
            import sys
            self._raise_compiler_error(obj, sys.exc_info()[1])

    @cython.final
    def _visitchild(self, child, parent, attrname, idx):
        self.access_path.append((parent, attrname, idx))
        try:
            return self._visit(child)
        finally:
            # Unwind even on error.  Otherwise a caller that catches an
            # exception from deeper in the tree would see a stale access_path.
            self.access_path.pop()

    def visitchildren(self, parent, attrs=None):
        return self._visitchildren(parent, attrs)

    @cython.final
    @cython.locals(idx=int)
    def _visitchildren(self, parent, attrs):
        """
        Visits the children of the given parent. If parent is None, returns
        immediately (returning None).

        If *attrs* is not None, only the child attributes listed in it are
        visited.

        The return value is a dictionary giving the results for each
        child (mapping the attribute name to either the return value
        or a list of return values (in the case of multiple children
        in an attribute)).
        """
        if parent is None: return None
        result = {}
        for attr in parent.child_attrs:
            if attrs is not None and attr not in attrs: continue
            child = getattr(parent, attr)
            if child is not None:
                if type(child) is list:
                    childretval = [self._visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
                else:
                    childretval = self._visitchild(child, parent, attr, None)
                    assert not isinstance(childretval, list), 'Cannot insert list here: %s in %r' % (attr, parent)
                result[attr] = childretval
        return result
215
216
class VisitorTransform(TreeVisitor):
    """
    A tree transform is a base class for visitors that want to do stream
    processing of the structure (rather than attributes etc.) of a tree.

    It implements __call__ to simply visit the argument node.

    It requires the visitor methods to return the nodes which should take
    the place of the visited node in the result tree (which can be the same
    or one or more replacement). Specifically, if the return value from
    a visitor method is:

    - None; the visited node will be removed (set to None if an attribute and
    removed if in a list). Inside a list, [] removes the node as well.
    - A single node; the visited node will be replaced by the returned node.
    - A list of nodes; the visited nodes will be replaced by all the nodes in the
    list. This will only work if the node was already a member of a list; if it
    was not, an exception will be raised. (Typically you want to ensure that you
    are within a StatListNode or similar before doing this.)
    """
    def visitchildren(self, parent, attrs=None):
        """
        Visit the children of *parent* and store the results back on it.

        For list attributes, the results are flattened one level and None
        entries are dropped.  Returns the raw result dictionary of
        _visitchildren, or None if *parent* is None.  That mirrors
        TreeVisitor.visitchildren, which also accepts a None parent.
        """
        result = self._visitchildren(parent, attrs)
        if result is None:
            # parent was None: nothing was visited, nothing to write back
            return None
        for attr, newnode in result.items():
            if type(newnode) is not list:
                setattr(parent, attr, newnode)
            else:
                # Flatten the list one level and remove any None
                newlist = []
                for x in newnode:
                    if x is None:
                        continue
                    if type(x) is list:
                        newlist.extend(x)
                    else:
                        newlist.append(x)
                setattr(parent, attr, newlist)
        return result

    def recurse_to_children(self, node):
        """Visit all children of *node* in place and return *node* itself."""
        self.visitchildren(node)
        return node

    def __call__(self, root):
        """Run the transform on *root* and return the (possibly new) root."""
        return self._visit(root)
260
class CythonTransform(VisitorTransform):
    """
    Certain common conventions and utilities for Cython transforms.

     - Sets up the context of the pipeline in self.context
     - Tracks directives in effect in self.current_directives
    """
    def __init__(self, context):
        super(CythonTransform, self).__init__()
        self.context = context

    def __call__(self, node):
        # Absolute import, like the module-level Cython.Compiler imports of
        # this file; an implicit relative "import ModuleNode" fails on
        # Python 3.  Kept function-local, presumably to avoid an import
        # cycle at module load time (TODO confirm).
        from Cython.Compiler import ModuleNode
        if isinstance(node, ModuleNode.ModuleNode):
            self.current_directives = node.directives
        return super(CythonTransform, self).__call__(node)

    def visit_CompilerDirectivesNode(self, node):
        # Directives apply only to this node's subtree; restore the outer
        # set afterwards, even if visiting the subtree raises.
        old = self.current_directives
        self.current_directives = node.directives
        try:
            self.visitchildren(node)
        finally:
            self.current_directives = old
        return node

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
288
class ScopeTrackingTransform(CythonTransform):
    # Records which kind of scope the visitor is currently inside.
    #   scope_type: one of 'module', 'function', 'cclass', 'pyclass', 'struct'
    #   scope_node: the node owning that scope
    # Both are first set by visit_ModuleNode; visit_scope saves and restores
    # them around each nested scope.

    def visit_ModuleNode(self, node):
        self.scope_type = 'module'
        self.scope_node = node
        self.visitchildren(node)
        return node

    def visit_scope(self, node, scope_type):
        saved_type = self.scope_type
        saved_node = self.scope_node
        self.scope_type, self.scope_node = scope_type, node
        self.visitchildren(node)
        self.scope_type = saved_type
        self.scope_node = saved_node
        return node

    def visit_CClassDefNode(self, node):
        return self.visit_scope(node, 'cclass')

    def visit_PyClassDefNode(self, node):
        return self.visit_scope(node, 'pyclass')

    def visit_FuncDefNode(self, node):
        return self.visit_scope(node, 'function')

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_scope(node, 'struct')
319
320
class EnvTransform(CythonTransform):
    """
    This transformation keeps a stack of the environments.

    env_stack holds (scope_node, scope) pairs, innermost last.  It is reset
    on every call and starts with the root node's scope.
    """
    def __call__(self, root):
        self.env_stack = []
        self.enter_scope(root, root.scope)
        return super(EnvTransform, self).__call__(root)

    def current_env(self):
        innermost = self.env_stack[-1]
        return innermost[1]

    def current_scope_node(self):
        innermost = self.env_stack[-1]
        return innermost[0]

    def global_scope(self):
        env = self.current_env()
        return env.global_scope()

    def enter_scope(self, node, scope):
        self.env_stack.append((node, scope))

    def exit_scope(self):
        self.env_stack.pop()

    def _visitchildren_in_scope(self, node, scope):
        # Visit the children of node with scope pushed as the current env.
        self.enter_scope(node, scope)
        self.visitchildren(node)
        self.exit_scope()
        return node

    def visit_FuncDefNode(self, node):
        return self._visitchildren_in_scope(node, node.local_scope)

    def visit_GeneratorBodyDefNode(self, node):
        # Does not push a scope of its own.
        self.visitchildren(node)
        return node

    def visit_ClassDefNode(self, node):
        return self._visitchildren_in_scope(node, node.scope)

    def visit_CStructOrUnionDefNode(self, node):
        return self._visitchildren_in_scope(node, node.scope)

    def visit_ScopedExprNode(self, node):
        if not node.expr_scope:
            self.visitchildren(node)
            return node
        return self._visitchildren_in_scope(node, node.expr_scope)

    def visit_CArgDeclNode(self, node):
        if not node.default:
            self.visitchildren(node)
            return node
        # Everything except the default value belongs to the current scope.
        non_default = [name for name in node.child_attrs if name != 'default']
        self.visitchildren(node, non_default)
        # Default arguments are evaluated in the outer scope.
        self.enter_scope(node, self.current_env().outer_scope)
        self.visitchildren(node, ('default',))
        self.exit_scope()
        return node
387
388
class NodeRefCleanupMixin(object):
    """
    Clean up references to nodes that were replaced.

    Replacements registered through replace() are remembered.  When a
    CloneNode or ResultRefNode still points at a replaced node, it is
    redirected to the replacement.

    NOTE: this implementation assumes that the replacement is
    done first, before hitting any further references during
    normal tree traversal.  This needs to be arranged by calling
    "self.visitchildren()" at a proper place in the transform
    and by ordering the "child_attrs" of nodes appropriately.
    """
    def __init__(self, *args):
        super(NodeRefCleanupMixin, self).__init__(*args)
        # replaced node -> node that took its place
        self._replacements = {}

    def visit_CloneNode(self, node):
        # Only descend if the referenced node has not already been replaced.
        if node.arg not in self._replacements:
            self.visitchildren(node)
        target = node.arg
        node.arg = self._replacements.get(target, target)
        return node

    def visit_ResultRefNode(self, node):
        current = node.expression
        needs_visit = current is None or current not in self._replacements
        if needs_visit:
            self.visitchildren(node)
            current = node.expression
        if current is None:
            return node
        node.expression = self._replacements.get(current, current)
        return node

    def replace(self, node, replacement):
        """Remember that *node* was replaced and return *replacement*."""
        self._replacements[node] = replacement
        return replacement
423
424
# Operator -> special method that Python calls for it.  Exposed as the bound
# .get() of the mapping, so unknown operators yield None.
# Note that 'in' maps to the *container's* __contains__ (operands swapped).
find_special_method_for_binary_operator = dict([
    # comparisons
    ('<',  '__lt__'),
    ('<=', '__le__'),
    ('==', '__eq__'),
    ('!=', '__ne__'),
    ('>=', '__ge__'),
    ('>',  '__gt__'),
    # arithmetic / bitwise
    ('+',  '__add__'),
    ('-',  '__sub__'),
    ('*',  '__mul__'),
    ('/',  '__truediv__'),
    ('//', '__floordiv__'),
    ('%',  '__mod__'),
    ('**', '__pow__'),
    ('<<', '__lshift__'),
    ('>>', '__rshift__'),
    ('&',  '__and__'),
    ('|',  '__or__'),
    ('^',  '__xor__'),
    # membership
    ('in', '__contains__'),
]).get


find_special_method_for_unary_operator = dict([
    ('not', '__not__'),
    ('~',   '__inv__'),
    ('-',   '__neg__'),
    ('+',   '__pos__'),
]).get
454
455
class MethodDispatcherTransform(EnvTransform):
    """
    Base class for transformations that want to intercept on specific
    builtin functions or methods of builtin types, including special
    methods triggered by Python operators.  Must run after declaration
    analysis when entries were assigned.

    Naming pattern for handler methods is as follows:

    * builtin functions: _handle_(general|simple|any)_function_NAME

    * builtin methods: _handle_(general|simple|any)_method_TYPENAME_METHODNAME

    "general" handlers are used when keyword arguments are present, "simple"
    ones otherwise, and "any" handlers are the fallback for both.  When no
    handler matches, _handle_function() / _handle_method() are called; by
    default they return the node unchanged.
    """
    # Only call nodes and Python operator nodes are intercepted here.

    def visit_GeneralCallNode(self, node):
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return node
        positional = node.positional_args
        if not isinstance(positional, ExprNodes.TupleNode):
            return node
        kwargs = node.keyword_args
        if kwargs and not isinstance(kwargs, ExprNodes.DictNode):
            # can't handle **kwargs
            return node
        return self._dispatch_to_handler(node, function, positional.args, kwargs)

    def visit_SimpleCallNode(self, node):
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            # C function call: arguments are kept directly on the node
            return self._dispatch_to_handler(node, function, node.args, None)
        arg_tuple = node.arg_tuple
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        return self._dispatch_to_handler(node, function, arg_tuple.args, None)

    def visit_PrimaryCmpNode(self, node):
        if not node.cascade:
            return self._visit_binop_node(node)
        # cascaded comparisons are not currently handled
        self.visitchildren(node)
        return node

    def visit_BinopNode(self, node):
        return self._visit_binop_node(node)

    def _dispatch_type_name(self, obj_type):
        # Builtin types dispatch by their own name; anything else is treated
        # as plain "object" (safety measure).
        if obj_type.is_builtin_type:
            return obj_type.name
        return "object"

    def _visit_binop_node(self, node):
        self.visitchildren(node)
        # FIXME: could special case 'not_in'
        method_name = find_special_method_for_binary_operator(node.operator)
        if not method_name:
            return node
        receiver, other = node.operand1, node.operand2
        if method_name == '__contains__':
            # 'x in y' calls y.__contains__(x): the container is the receiver
            receiver, other = other, receiver
        return self._dispatch_to_method_handler(
            method_name, None, False, self._dispatch_type_name(receiver.type),
            node, None, [receiver, other], None)

    def visit_UnopNode(self, node):
        self.visitchildren(node)
        method_name = find_special_method_for_unary_operator(node.operator)
        if not method_name:
            return node
        operand = node.operand
        return self._dispatch_to_method_handler(
            method_name, None, False, self._dispatch_type_name(operand.type),
            node, None, [operand], None)

    ### dispatch to specific handlers

    def _find_handler(self, match_name, has_kwargs):
        if has_kwargs:
            call_type = 'general'
        else:
            call_type = 'simple'
        handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
        if handler is None:
            handler = getattr(self, '_handle_any_%s' % match_name, None)
        return handler

    def _delegate_to_assigned_value(self, node, function, arg_list, kwargs):
        # The name has exactly one reaching assignment: dispatch on its
        # right-hand side instead, if that is a plain name or an attribute
        # of a plain name.
        value = function.cf_state[0].rhs
        if value.is_name:
            entry = value.entry
        elif value.is_attribute and value.obj.is_name:
            entry = value.obj.entry
        else:
            return node
        if not entry or len(entry.cf_assignments) > 1:
            # the (underlying) variable might have been reassigned => play safe
            return node
        return self._dispatch_to_handler(node, value, arg_list, kwargs)

    def _dispatch_to_handler(self, node, function, arg_list, kwargs):
        if function.is_name:
            # we only consider functions that are either builtin
            # Python functions or builtins that were already replaced
            # into a C function call (defined in the builtin scope)
            entry = function.entry
            if not entry:
                return node
            if not (entry.is_builtin or
                    entry is self.current_env().builtin_scope().lookup_here(function.name)):
                cf_state = function.cf_state
                if cf_state and cf_state.is_single:
                    # we know the value of the variable => see if it's usable instead
                    return self._delegate_to_assigned_value(
                        node, function, arg_list, kwargs)
                return node
            handler = self._find_handler("function_%s" % function.name, kwargs)
            if handler is None:
                return self._handle_function(node, function.name, function, arg_list, kwargs)
            if kwargs:
                return handler(node, function, arg_list, kwargs)
            return handler(node, function, arg_list)

        if not (function.is_attribute and function.type.is_pyobject):
            return node

        self_arg = function.obj
        obj_type = self_arg.type
        is_unbound_method = False
        if not obj_type.is_builtin_type:
            type_name = "object"  # safety measure
        elif (obj_type is Builtin.type_type and self_arg.is_name and
                arg_list and arg_list[0].type.is_pyobject):
            # calling an unbound method like 'list.append(L,x)'
            # (ignoring 'type.mro()' here ...)
            type_name = self_arg.name
            self_arg = None
            is_unbound_method = True
        else:
            type_name = obj_type.name
        return self._dispatch_to_method_handler(
            function.attribute, self_arg, is_unbound_method, type_name,
            node, function, arg_list, kwargs)

    def _dispatch_to_method_handler(self, attr_name, self_arg,
                                    is_unbound_method, type_name,
                                    node, function, arg_list, kwargs):
        handler = self._find_handler(
            "method_%s_%s" % (type_name, attr_name), kwargs)
        if handler is None and (attr_name in TypeSlots.method_name_to_slot
                                or attr_name == '__new__'):
            # fall back to a generic handler for the type slot
            handler = self._find_handler("slot%s" % attr_name, kwargs)
        if handler is None:
            return self._handle_method(
                node, type_name, attr_name, function,
                arg_list, is_unbound_method, kwargs)
        if self_arg is not None:
            arg_list = [self_arg] + list(arg_list)
        if kwargs:
            return handler(node, function, arg_list, is_unbound_method, kwargs)
        return handler(node, function, arg_list, is_unbound_method)

    def _handle_function(self, node, function_name, function, arg_list, kwargs):
        """Fallback handler"""
        return node

    def _handle_method(self, node, type_name, attr_name, function,
                       arg_list, is_unbound_method, kwargs):
        """Fallback handler"""
        return node
644
645
class RecursiveNodeReplacer(VisitorTransform):
    """
    Recursively replace all occurrences of a node in a subtree by
    another node.  Matching is by identity (``is``).
    """
    def __init__(self, orig_node, new_node):
        super(RecursiveNodeReplacer, self).__init__()
        self.orig_node = orig_node
        self.new_node = new_node

    def visit_Node(self, node):
        # Children are processed first, so replacements happen bottom-up.
        self.visitchildren(node)
        if node is not self.orig_node:
            return node
        return self.new_node
661
def recursively_replace_node(tree, old_node, new_node):
    """
    Replace every occurrence of *old_node* inside *tree* by *new_node*.

    The tree is modified in place.  The (possibly replaced) root is returned,
    which matters when *tree* itself is *old_node*.  Previously that result
    was silently dropped; callers that ignore the return value still work.
    """
    replace_in = RecursiveNodeReplacer(old_node, new_node)
    return replace_in(tree)
665
666
class NodeFinder(TreeVisitor):
    """
    Find out if a node appears in a subtree.

    After visiting, ``found`` is True if the node (compared by identity)
    was encountered.
    """
    def __init__(self, node):
        super(NodeFinder, self).__init__()
        self.node = node
        self.found = False

    def visit_Node(self, node):
        # Once the node has been found, every further visit is a no-op.
        if not self.found:
            if node is self.node:
                self.found = True
            else:
                self._visitchildren(node, None)
683
def tree_contains(tree, node):
    """Return True if *node* (by identity) occurs anywhere in *tree*."""
    searcher = NodeFinder(node)
    searcher.visit(tree)
    return searcher.found
688
689
690# Utils
def replace_node(ptr, value):
    """Replaces a node. ptr is of the form used on the access path stack
    (parent, attrname, listidx|None): with a list index, the list element is
    overwritten; otherwise the attribute itself is reassigned.
    """
    owner, attribute, position = ptr
    if position is not None:
        container = getattr(owner, attribute)
        container[position] = value
    else:
        setattr(owner, attribute, value)
700
class PrintTree(TreeVisitor):
    """Prints a representation of the tree to standard output.
    Subclass and override repr_of to provide more information
    about nodes. """
    def __init__(self):
        TreeVisitor.__init__(self)
        self._indent = ""

    def indent(self):
        self._indent += "  "

    def unindent(self):
        self._indent = self._indent[:-2]

    def __call__(self, tree, phase=None):
        print("Parse tree dump at phase '%s'" % phase)
        self.visit(tree)
        return tree

    # Don't do anything about process_list, the defaults gives
    # nice-looking name[idx] nodes which will visually appear
    # under the parent-node, not displaying the list itself in
    # the hierarchy.
    def visit_Node(self, node):
        if len(self.access_path) == 0:
            name = "(root)"
        else:
            parent, attr, idx = self.access_path[-1]
            if idx is not None:
                name = "%s[%d]" % (attr, idx)
            else:
                name = attr
        print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
        self.indent()
        self.visitchildren(node)
        self.unindent()
        return node

    def repr_of(self, node):
        """
        Return a one-line description of *node*: its class name plus its
        type/name for expressions and definitions, or its source position
        (file:line:col) otherwise.
        """
        if node is None:
            return "(none)"
        result = node.__class__.__name__
        if isinstance(node, ExprNodes.NameNode):
            result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
        elif isinstance(node, Nodes.DefNode):
            result += "(name=\"%s\")" % node.name
        elif isinstance(node, ExprNodes.ExprNode):
            result += "(type=%s)" % repr(node.type)
        else:
            pos = getattr(node, 'pos', None)
            if pos:
                path = pos[0]
                # The source descriptor may be missing (dump_node guards
                # against this too); print it as-is in that case.
                if path:
                    # Keep only the file name, for either separator,
                    # independently of the host OS.
                    path = path.get_description()
                    path = path.split('/')[-1].split('\\')[-1]
                result += "(pos=(%s:%s:%s))" % (path, pos[1], pos[2])
        return result
760
if __name__ == "__main__":
    # Running this module directly executes the docstring examples.
    from doctest import testmod
    testmod()
764