1# mako/_ast_util.py
2# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
3#
4# This module is part of Mako and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
7"""
8    ast
9    ~~~
10
11    The `ast` module helps Python applications to process trees of the Python
12    abstract syntax grammar.  The abstract syntax itself might change with
13    each Python release; this module helps to find out programmatically what
14    the current grammar looks like and allows modifications of it.
15
16    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
17    a flag to the `compile()` builtin function or by using the `parse()`
18    function from this module.  The result will be a tree of objects whose
19    classes all inherit from `ast.AST`.
20
21    A modified abstract syntax tree can be compiled into a Python code object
22    using the built-in `compile()` function.
23
24    Additionally various helper functions are provided that make working with
25    the trees simpler.  The main intention of the helper functions and this
26    module in general is to provide an easy to use interface for libraries
27    that work tightly with the python syntax (template engines for example).
28
29
30    :copyright: Copyright 2008 by Armin Ronacher.
31    :license: Python License.
32"""
33from _ast import *
34from mako.compat import arg_stringname
35
# Source text for each operator node class, used by `SourceGenerator` to
# render boolean, binary, comparison and unary operations.

BOOLOP_SYMBOLS = {
    And: 'and',
    Or: 'or'
}

BINOP_SYMBOLS = {
    Add: '+',
    Sub: '-',
    Mult: '*',
    Div: '/',
    FloorDiv: '//',
    Mod: '%',
    Pow: '**',
    LShift: '<<',
    RShift: '>>',
    BitOr: '|',
    BitAnd: '&',
    BitXor: '^'
}

# The matrix multiplication operator node only exists on Python 3.5+;
# `MatMult` is simply not defined by `_ast` on older interpreters.
try:
    BINOP_SYMBOLS[MatMult] = '@'
except NameError:
    pass

CMPOP_SYMBOLS = {
    Eq: '==',
    Gt: '>',
    GtE: '>=',
    In: 'in',
    Is: 'is',
    IsNot: 'is not',
    Lt: '<',
    LtE: '<=',
    NotEq: '!=',
    NotIn: 'not in'
}

UNARYOP_SYMBOLS = {
    Invert: '~',
    Not: 'not',
    UAdd: '+',
    USub: '-'
}

# Union of all the tables above (the operator classes are all distinct, so
# no entry shadows another).
ALL_SYMBOLS = {}
ALL_SYMBOLS.update(BOOLOP_SYMBOLS)
ALL_SYMBOLS.update(BINOP_SYMBOLS)
ALL_SYMBOLS.update(CMPOP_SYMBOLS)
ALL_SYMBOLS.update(UNARYOP_SYMBOLS)
80
81
def parse(expr, filename='<unknown>', mode='exec'):
    """
    Compile the source text `expr` into an AST node instead of bytecode.

    `filename` is what gets reported in syntax errors; `mode` is passed
    straight to the `compile()` builtin ('exec', 'eval' or 'single').
    """
    ast_only_flag = PyCF_ONLY_AST
    tree = compile(expr, filename, mode, ast_only_flag)
    return tree
85
86
def to_source(node, indent_with=' ' * 4):
    """
    Render a node tree back into Python source code, mainly as a debugging
    aid (including for hand-built trees not produced by the parser).

    The generated source may be evaluable even when the AST itself is not
    compilable: the tree carries extra data (such as locations) that is
    dropped during the conversion.

    `indent_with` is emitted once per indentation level.  It defaults to four
    spaces as recommended by PEP 8 but can be changed to match a project's
    style guide.
    """
    emitter = SourceGenerator(indent_with)
    emitter.visit(node)
    fragments = emitter.result
    return ''.join(fragments)
105
106
def dump(node):
    """
    Return a verbose, constructor-like string representation of `node` and
    everything below it, e.g. ``Name(id='x', ctx=Load())``.  Meant for
    debugging.  Raises `TypeError` if `node` is not an AST node.
    """
    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)

    def _render(value):
        # Lists (e.g. statement bodies) are rendered element-wise.
        if isinstance(value, list):
            return '[%s]' % ', '.join([_render(item) for item in value])
        if isinstance(value, AST):
            parts = ['%s=%s' % (name, _render(child))
                     for name, child in iter_fields(value)]
            return '%s(%s)' % (value.__class__.__name__, ', '.join(parts))
        # Plain values (strings, numbers, None) use their repr.
        return repr(value)

    return _render(node)
123
124
def copy_location(new_node, old_node):
    """
    Copy the source location hints from `old_node` to `new_node` and return
    `new_node`.

    `lineno` and `col_offset` are copied, and so are `end_lineno` and
    `end_col_offset` on interpreters whose nodes declare them.  An attribute
    is only copied when both node classes list it in `_attributes` and
    `old_node` actually has it set.
    """
    for attr in 'lineno', 'col_offset', 'end_lineno', 'end_col_offset':
        if attr in old_node._attributes and attr in new_node._attributes \
           and hasattr(old_node, attr):
            setattr(new_node, attr, getattr(old_node, attr))
    return new_node
135
136
def fix_missing_locations(node):
    """
    Fill in missing `lineno` / `col_offset` attributes throughout a tree so
    that it can be compiled.

    Any node that supports these attributes but lacks them inherits the
    values of its closest ancestor that has them, or 1 / 0 at the root.
    Nodes that already carry a location are left untouched, and their
    location is what their descendants inherit.  Unlike `copy_location` this
    works recursively.  Returns `node`.
    """
    def _fix(current, lineno, col_offset):
        inherited = {'lineno': lineno, 'col_offset': col_offset}
        for attr in ('lineno', 'col_offset'):
            if attr not in current._attributes:
                continue
            if hasattr(current, attr):
                # An existing location becomes the default for children.
                inherited[attr] = getattr(current, attr)
            else:
                setattr(current, attr, inherited[attr])
        for child in iter_child_nodes(current):
            _fix(child, inherited['lineno'], inherited['col_offset'])

    _fix(node, 1, 0)
    return node
164
165
def increment_lineno(node, n=1):
    """
    Shift the line numbers of `node` and every node below it by `n`.  This
    is useful to "move code" to a different location in a file.

    Every node whose class declares a `lineno` attribute gets
    ``lineno + n``; a missing `lineno` is treated as 0.  When `end_lineno`
    is declared and set (not None), it is shifted by `n` as well.  Returns
    `node` for convenience.
    """
    # `walk` already yields `node` itself, so the root is not special-cased.
    for child in walk(node):
        if 'lineno' in child._attributes:
            child.lineno = getattr(child, 'lineno', 0) + n
        if 'end_lineno' in child._attributes and \
           getattr(child, 'end_lineno', None) is not None:
            child.end_lineno += n
    return node
175
176
def iter_fields(node):
    """
    Yield ``(name, value)`` for every field listed in `node._fields` that is
    actually set on `node`; unset fields are silently skipped.  Yields
    nothing when `node` has no (or an empty) `_fields`.
    """
    # CPython 2.5 compat: some nodes lack `_fields` or have it empty.
    field_names = getattr(node, '_fields', None)
    if not field_names:
        return
    for name in field_names:
        try:
            value = getattr(node, name)
        except AttributeError:
            continue
        yield name, value
187
188
189def get_fields(node):
190    """Like `iter_fiels` but returns a dict."""
191    return dict(iter_fields(node))
192
193
def iter_child_nodes(node):
    """
    Iterate over the direct child nodes of a node, in field order.

    Both fields holding a single AST node and list fields are considered;
    non-AST values (strings, numbers, None) are ignored.
    """
    for _, value in iter_fields(node):
        candidates = value if isinstance(value, list) else [value]
        for candidate in candidates:
            if isinstance(candidate, AST):
                yield candidate
203
204
def get_child_nodes(node):
    """Return the direct child nodes of `node` as a list (see `iter_child_nodes`)."""
    children = []
    children.extend(iter_child_nodes(node))
    return children
208
209
def get_compile_mode(node):
    """
    Return the `compile()` mode matching a `mod` node.

    The mode is 'eval' for `Expression`, 'single' for `Interactive` and
    'exec' for any other `mod` node (e.g. `Module`).  A `TypeError` is raised
    when `node` is not a `mod` node.
    """
    if not isinstance(node, mod):
        raise TypeError('expected mod node, got %r' % node.__class__.__name__)
    # 'exec' (not 'expr', which compile() rejects) is the mode for modules.
    return {
        Expression: 'eval',
        Interactive: 'single'
    }.get(node.__class__, 'exec')
221
222
def get_docstring(node):
    """
    Return the docstring for the given node or `None` if no docstring can be
    found.  If the node provided does not accept docstrings a `TypeError`
    will be raised.

    A docstring is a string literal that forms the first statement of the
    body.
    """
    if not isinstance(node, (FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    if not node.body:
        return None
    first = node.body[0]
    # The parser wraps a bare string statement in an `Expr` node, so the
    # literal lives on `.value`, not on the statement itself.
    if isinstance(first, Expr) and isinstance(first.value, Str):
        return first.value.s
    # Kept for hand-built trees that place a `Str` directly in the body.
    if isinstance(first, Str):
        return first.s
    return None
233
234
def walk(node):
    """
    Yield `node` and every node below it, in breadth-first order.  Handy for
    in-place modifications where neither context nor ordering matters.
    """
    from collections import deque
    pending = deque((node,))
    while pending:
        current = pending.popleft()
        # Queue the children before handing `current` to the caller.
        pending.extend(iter_child_nodes(current))
        yield current
246
247
class NodeVisitor(object):
    """
    Walks an abstract syntax tree and calls a visitor function for every
    node found.  Whatever a visitor function returns is passed back to the
    caller of `visit`.

    By default the visitor function for a node is the method named
    ``'visit_'`` + the node's class name, so a `TryFinally` node is handled
    by `visit_TryFinally`.  Override `get_visitor` to change that lookup.
    Nodes without a visitor function (lookup returns `None`) are handled by
    `generic_visit`.

    Do not use `NodeVisitor` to change nodes while traversing; use
    `NodeTransformer`, which supports modifications.
    """

    def get_visitor(self, node):
        """
        Return the visitor function for `node`, or `None` when this visitor
        has no method for the node's class (in which case `generic_visit`
        is used).
        """
        return getattr(self, 'visit_%s' % node.__class__.__name__, None)

    def visit(self, node):
        """Visit `node` and return the result of its visitor function."""
        visitor = self.get_visitor(node)
        if visitor is None:
            return self.generic_visit(node)
        return visitor(node)

    def generic_visit(self, node):
        """
        Fallback for nodes without a visitor function: visit every AST child
        (held directly in a field or inside a list field) in field order.
        """
        for _, value in iter_fields(node):
            if isinstance(value, AST):
                self.visit(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
290
291
class NodeTransformer(NodeVisitor):
    """
    Walks an abstract syntax tree and lets visitor functions replace or
    remove the nodes they visit.

    The return value of a visitor function decides what happens to the
    visited node.  `None` removes the node from its location, any other
    value replaces it, and returning the original node keeps it unchanged.

    Example transformer rewriting every name `foo` into `data['foo']`::

        class RewriteName(NodeTransformer):

            def visit_Name(self, node):
                return copy_location(Subscript(
                    value=Name(id='data', ctx=Load()),
                    slice=Index(value=Str(s=node.id)),
                    ctx=node.ctx
                ), node)

    A visitor function for a node that has children must either transform
    those children itself or call `generic_visit` for the node first.

    Nodes stored in a list field (for example, statements) may also be
    replaced by a list of nodes, which is spliced into the enclosing list.

    Typical usage::

        node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        """
        Visit every AST child of `node`, applying the returned replacements
        in place, and return `node`.

        For a single-node field, a `None` result deletes the attribute.
        Inside list fields, `None` drops the item, a non-AST result (a
        sequence of nodes) is spliced in, and non-AST list items are kept
        as they are.
        """
        for field, current in iter_fields(node):
            if isinstance(current, AST):
                replacement = self.visit(current)
                if replacement is None:
                    delattr(node, field)
                else:
                    setattr(node, field, replacement)
            elif isinstance(current, list):
                rebuilt = []
                for item in current:
                    if not isinstance(item, AST):
                        rebuilt.append(item)
                        continue
                    outcome = self.visit(item)
                    if outcome is None:
                        continue
                    if isinstance(outcome, AST):
                        rebuilt.append(outcome)
                    else:
                        rebuilt.extend(outcome)
                # Mutate the existing list so other references see the change.
                current[:] = rebuilt
        return node
349
350
class SourceGenerator(NodeVisitor):
    """
    This visitor is able to transform a well formed syntax tree into python
    sourcecode.  For more details have a look at the docstring of the
    `to_source` function.

    Generated text is collected as string fragments in `self.result`;
    ``''.join(self.result)`` yields the final source.

    Both node layouts are understood:

    * Python 2: `TryExcept`/`TryFinally`, `Print`, `Repr`, and
      `starargs`/`kwargs` on calls and class definitions.
    * Python 3: `Try`, `withitem` lists, keyword-only and positional-only
      arguments, `Constant`, and ``**`` keywords with ``arg=None``.
    """

    def __init__(self, indent_with):
        # Fragments of generated source, in output order.
        self.result = []
        # Text emitted once per indentation level at the start of a line.
        self.indent_with = indent_with
        # Current statement nesting depth.
        self.indentation = 0
        # Line breaks requested via newline() but not written yet.
        self.new_lines = 0

    def write(self, x):
        """Append `x`, first flushing pending line breaks and indentation."""
        if self.new_lines:
            # No leading line breaks at the very start of the output.
            if self.result:
                self.result.append('\n' * self.new_lines)
            self.result.append(self.indent_with * self.indentation)
            self.new_lines = 0
        self.result.append(x)

    def newline(self, n=1):
        """Request at least `n` line breaks before the next write."""
        self.new_lines = max(self.new_lines, n)

    def body(self, statements):
        """Emit `statements` one indentation level deeper."""
        # Start the block on a fresh line even if a statement's visitor does
        # not request one itself.
        self.newline()
        self.indentation += 1
        for stmt in statements:
            self.visit(stmt)
        self.indentation -= 1

    def body_or_else(self, node):
        """Emit `node.body`, plus an ``else:`` clause if `node.orelse`."""
        self.body(node.body)
        if node.orelse:
            self._clause('else:', node.orelse)

    def _clause(self, header, statements):
        """Emit `header` on its own line followed by the indented statements."""
        self.newline()
        self.write(header)
        self.body(statements)

    def signature(self, node):
        """Emit an `arguments` node, without the surrounding parentheses."""
        want_comma = []

        def write_comma():
            if want_comma:
                self.write(', ')
            else:
                want_comma.append(True)

        # These fields are absent on older node layouts.
        posonlyargs = list(getattr(node, 'posonlyargs', None) or ())
        kwonlyargs = list(getattr(node, 'kwonlyargs', None) or ())
        kw_defaults = list(getattr(node, 'kw_defaults', None) or ())

        # `defaults` belong to the *last* positional parameters.
        positional = posonlyargs + list(node.args)
        defaults = list(node.defaults)
        padding = [None] * (len(positional) - len(defaults))
        for idx, (arg, default) in enumerate(zip(positional,
                                                 padding + defaults)):
            write_comma()
            self.visit(arg)
            if default is not None:
                self.write('=')
                self.visit(default)
            if idx + 1 == len(posonlyargs):
                write_comma()
                self.write('/')

        if node.vararg is not None:
            write_comma()
            self.write('*' + arg_stringname(node.vararg))
        elif kwonlyargs:
            # Without `*args`, a bare `*` introduces keyword-only parameters.
            write_comma()
            self.write('*')

        # `kw_defaults` runs parallel to `kwonlyargs`; None means no default.
        kw_padding = [None] * (len(kwonlyargs) - len(kw_defaults))
        for arg, default in zip(kwonlyargs, kw_defaults + kw_padding):
            write_comma()
            self.visit(arg)
            if default is not None:
                self.write('=')
                self.visit(default)

        if node.kwarg is not None:
            write_comma()
            self.write('**' + arg_stringname(node.kwarg))

    def decorators(self, node):
        """Emit one ``@decorator`` line per entry in `node.decorator_list`."""
        for decorator in node.decorator_list:
            self.newline()
            self.write('@')
            self.visit(decorator)

    def _write_keyword(self, kw):
        """Emit a `keyword` node as ``name=value``, or ``**value`` if `arg` is None."""
        if kw.arg is None:
            self.write('**')
        else:
            self.write(kw.arg + '=')
        self.visit(kw.value)

    def _write_legacy_star_args(self, node, separator):
        """
        Emit the `starargs` / `kwargs` fields of a call or class definition
        (only present on older node layouts), calling `separator()` before
        each one that is set.
        """
        starargs = getattr(node, 'starargs', None)
        if starargs is not None:
            separator()
            self.write('*')
            self.visit(starargs)
        kwargs = getattr(node, 'kwargs', None)
        if kwargs is not None:
            separator()
            self.write('**')
            self.visit(kwargs)

    # Statements

    def visit_Assign(self, node):
        self.newline()
        # Several targets mean chained assignment: ``a = b = value``.
        for target in node.targets:
            self.visit(target)
            self.write(' = ')
        self.visit(node.value)

    def visit_AugAssign(self, node):
        self.newline()
        self.visit(node.target)
        self.write(' %s= ' % BINOP_SYMBOLS[type(node.op)])
        self.visit(node.value)

    def visit_ImportFrom(self, node):
        self.newline()
        # `module` is None for ``from . import x``; `level` is optional in
        # the grammar and may be None.
        self.write('from %s%s import ' % ('.' * (node.level or 0),
                                          node.module or ''))
        for idx, item in enumerate(node.names):
            if idx:
                self.write(', ')
            self.visit(item)

    def visit_Import(self, node):
        self.newline()
        self.write('import ')
        for idx, item in enumerate(node.names):
            if idx:
                self.write(', ')
            self.visit(item)

    def visit_Expr(self, node):
        self.newline()
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.newline(n=2)
        self.decorators(node)
        self.newline()
        self.write('def %s(' % node.name)
        self.signature(node.args)
        self.write('):')
        self.body(node.body)

    def visit_ClassDef(self, node):
        have_args = []

        def paren_or_comma():
            if have_args:
                self.write(', ')
            else:
                have_args.append(True)
                self.write('(')

        self.newline(n=3)
        self.decorators(node)
        self.newline()
        self.write('class %s' % node.name)
        for base in node.bases:
            paren_or_comma()
            self.visit(base)
        # `keywords` is missing on very old node layouts.
        for kw in getattr(node, 'keywords', None) or ():
            paren_or_comma()
            self._write_keyword(kw)
        self._write_legacy_star_args(node, paren_or_comma)
        self.write('):' if have_args else ':')
        self.body(node.body)

    def visit_If(self, node):
        self.newline()
        self.write('if ')
        self.visit(node.test)
        self.write(':')
        self.body(node.body)
        while True:
            else_ = node.orelse
            if len(else_) == 1 and isinstance(else_[0], If):
                # A lone nested `If` in the else branch is an `elif`.
                node = else_[0]
                self.newline()
                self.write('elif ')
                self.visit(node.test)
                self.write(':')
                self.body(node.body)
            else:
                # Only emit `else:` when there is something in it.
                if else_:
                    self._clause('else:', else_)
                break

    def visit_For(self, node):
        self.newline()
        self.write('for ')
        self.visit(node.target)
        self.write(' in ')
        self.visit(node.iter)
        self.write(':')
        self.body_or_else(node)

    def visit_While(self, node):
        self.newline()
        self.write('while ')
        self.visit(node.test)
        self.write(':')
        self.body_or_else(node)

    def visit_With(self, node):
        self.newline()
        self.write('with ')
        # Newer layouts hold a list of `withitem`s; older ones store a single
        # context expression directly on the node.
        if hasattr(node, 'items'):
            items = [(item.context_expr, item.optional_vars)
                     for item in node.items]
        else:
            items = [(node.context_expr, node.optional_vars)]
        for idx, (context_expr, optional_vars) in enumerate(items):
            if idx:
                self.write(', ')
            self.visit(context_expr)
            if optional_vars is not None:
                self.write(' as ')
                self.visit(optional_vars)
        self.write(':')
        self.body(node.body)

    def visit_Pass(self, node):
        self.newline()
        self.write('pass')

    def visit_Print(self, node):
        # XXX: python 2.6 only
        self.newline()
        self.write('print')
        want_comma = False
        if node.dest is not None:
            self.write(' >> ')
            self.visit(node.dest)
            want_comma = True
        for value in node.values:
            self.write(', ' if want_comma else ' ')
            self.visit(value)
            want_comma = True
        if not node.nl:
            self.write(',')

    def visit_Delete(self, node):
        self.newline()
        self.write('del ')
        for idx, target in enumerate(node.targets):
            if idx:
                self.write(', ')
            self.visit(target)

    def visit_TryExcept(self, node):
        # Python 2 layout: try / except [/ else], without finally.
        self.newline()
        self.write('try:')
        self.body(node.body)
        for handler in node.handlers:
            self.visit(handler)
        if node.orelse:
            self._clause('else:', node.orelse)

    def visit_TryFinally(self, node):
        # Python 2 layout: try / finally.
        self.newline()
        self.write('try:')
        self.body(node.body)
        self._clause('finally:', node.finalbody)

    def visit_Try(self, node):
        # Python 3 layout: one node for try / except / else / finally.
        self.newline()
        self.write('try:')
        self.body(node.body)
        for handler in node.handlers:
            self.visit(handler)
        if node.orelse:
            self._clause('else:', node.orelse)
        if node.finalbody:
            self._clause('finally:', node.finalbody)

    def visit_Global(self, node):
        self.newline()
        self.write('global ' + ', '.join(node.names))

    def visit_Nonlocal(self, node):
        self.newline()
        self.write('nonlocal ' + ', '.join(node.names))

    def visit_Return(self, node):
        self.newline()
        self.write('return')
        if node.value is not None:
            self.write(' ')
            self.visit(node.value)

    def visit_Break(self, node):
        self.newline()
        self.write('break')

    def visit_Continue(self, node):
        self.newline()
        self.write('continue')

    def visit_Raise(self, node):
        # XXX: Python 2.6 / 3.0 compatibility
        self.newline()
        self.write('raise')
        if hasattr(node, 'exc') and node.exc is not None:
            self.write(' ')
            self.visit(node.exc)
            if node.cause is not None:
                self.write(' from ')
                self.visit(node.cause)
        elif hasattr(node, 'type') and node.type is not None:
            self.write(' ')
            self.visit(node.type)
            if node.inst is not None:
                self.write(', ')
                self.visit(node.inst)
            if node.tback is not None:
                self.write(', ')
                self.visit(node.tback)

    # Expressions

    def visit_Attribute(self, node):
        self.visit(node.value)
        self.write('.' + node.attr)

    def visit_Call(self, node):
        want_comma = []

        def write_comma():
            if want_comma:
                self.write(', ')
            else:
                want_comma.append(True)

        self.visit(node.func)
        self.write('(')
        for arg in node.args:
            write_comma()
            self.visit(arg)
        for kw in node.keywords:
            write_comma()
            self._write_keyword(kw)
        self._write_legacy_star_args(node, write_comma)
        self.write(')')

    def visit_Name(self, node):
        self.write(node.id)

    def visit_NameConstant(self, node):
        self.write(str(node.value))

    def visit_Constant(self, node):
        # Newer parsers emit `Constant` for every literal (numbers, strings,
        # bytes, True/False/None, Ellipsis).
        self.write(repr(node.value))

    def visit_arg(self, node):
        self.write(node.arg)

    def visit_Str(self, node):
        self.write(repr(node.s))

    def visit_Bytes(self, node):
        self.write(repr(node.s))

    def visit_Num(self, node):
        self.write(repr(node.n))

    def visit_Tuple(self, node):
        self.write('(')
        idx = -1
        for idx, item in enumerate(node.elts):
            if idx:
                self.write(', ')
            self.visit(item)
        # A one-element tuple needs the trailing comma: ``(x,)``.
        self.write(',)' if idx == 0 else ')')

    def sequence_visit(left, right):
        def visit(self, node):
            self.write(left)
            for idx, item in enumerate(node.elts):
                if idx:
                    self.write(', ')
                self.visit(item)
            self.write(right)
        return visit

    visit_List = sequence_visit('[', ']')
    visit_Set = sequence_visit('{', '}')
    del sequence_visit

    def visit_Dict(self, node):
        self.write('{')
        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
            if idx:
                self.write(', ')
            if key is None:
                # ``{**mapping}`` unpacking is stored with a None key.
                self.write('**')
            else:
                self.visit(key)
                self.write(': ')
            self.visit(value)
        self.write('}')

    def visit_BinOp(self, node):
        self.write('(')
        self.visit(node.left)
        self.write(' %s ' % BINOP_SYMBOLS[type(node.op)])
        self.visit(node.right)
        self.write(')')

    def visit_BoolOp(self, node):
        self.write('(')
        for idx, value in enumerate(node.values):
            if idx:
                self.write(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
            self.visit(value)
        self.write(')')

    def visit_Compare(self, node):
        self.write('(')
        self.visit(node.left)
        for op, right in zip(node.ops, node.comparators):
            self.write(' %s ' % CMPOP_SYMBOLS[type(op)])
            self.visit(right)
        self.write(')')

    def visit_UnaryOp(self, node):
        self.write('(')
        op = UNARYOP_SYMBOLS[type(node.op)]
        self.write(op)
        if op == 'not':
            self.write(' ')
        self.visit(node.operand)
        self.write(')')

    def visit_Subscript(self, node):
        self.visit(node.value)
        self.write('[')
        self.visit(node.slice)
        self.write(']')

    def visit_Slice(self, node):
        if node.lower is not None:
            self.visit(node.lower)
        self.write(':')
        if node.upper is not None:
            self.visit(node.upper)
        if node.step is not None:
            self.write(':')
            if not (isinstance(node.step, Name) and node.step.id == 'None'):
                self.visit(node.step)

    def visit_ExtSlice(self, node):
        for idx, item in enumerate(node.dims):
            if idx:
                self.write(', ')
            self.visit(item)

    def visit_Yield(self, node):
        self.write('yield')
        if node.value is not None:
            self.write(' ')
            self.visit(node.value)

    def visit_Lambda(self, node):
        self.write('lambda ')
        self.signature(node.args)
        self.write(': ')
        self.visit(node.body)

    def visit_Ellipsis(self, node):
        self.write('Ellipsis')

    def generator_visit(left, right):
        def visit(self, node):
            self.write(left)
            self.visit(node.elt)
            for comprehension in node.generators:
                self.visit(comprehension)
            self.write(right)
        return visit

    visit_ListComp = generator_visit('[', ']')
    visit_GeneratorExp = generator_visit('(', ')')
    visit_SetComp = generator_visit('{', '}')
    del generator_visit

    def visit_DictComp(self, node):
        self.write('{')
        self.visit(node.key)
        self.write(': ')
        self.visit(node.value)
        for comprehension in node.generators:
            self.visit(comprehension)
        self.write('}')

    def visit_IfExp(self, node):
        # Parenthesized so the conditional keeps its meaning when nested in
        # an operator, e.g. ``(a if b else c) + 1``.
        self.write('(')
        self.visit(node.body)
        self.write(' if ')
        self.visit(node.test)
        self.write(' else ')
        self.visit(node.orelse)
        self.write(')')

    def visit_Starred(self, node):
        self.write('*')
        self.visit(node.value)

    def visit_Repr(self, node):
        # XXX: python 2.6 only
        self.write('`')
        self.visit(node.value)
        self.write('`')

    # Helper Nodes

    def visit_alias(self, node):
        self.write(node.name)
        if node.asname is not None:
            self.write(' as ' + node.asname)

    def visit_comprehension(self, node):
        self.write(' for ')
        self.visit(node.target)
        self.write(' in ')
        self.visit(node.iter)
        for if_ in node.ifs:
            self.write(' if ')
            self.visit(if_)

    def visit_excepthandler(self, node):
        self.newline()
        self.write('except')
        if node.type is not None:
            self.write(' ')
            self.visit(node.type)
            if node.name is not None:
                self.write(' as ')
                # The bound name is a `Name` node in Python 2 and a plain
                # string in Python 3.
                if isinstance(node.name, AST):
                    self.visit(node.name)
                else:
                    self.write(node.name)
        self.write(':')
        self.body(node.body)

    # Visitors are looked up by class name, and current interpreters name the
    # handler node `ExceptHandler`, so expose the visitor under both names.
    visit_ExceptHandler = visit_excepthandler
846