1# mako/pyparser.py
2# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
3#
4# This module is part of Mako and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
"""Handles parsing of Python code.

Parsing to AST is done via the ``_ast`` module, through Mako's
``_ast_util`` helpers.
"""
12
13from mako import exceptions, util, compat
14from mako.compat import arg_stringname
15import operator
16
# Names that are never reported as undeclared identifiers (this is
# notably much smaller than the total set of keys in __builtins__).
# Python 3 additionally treats 'print' this way.
reserved = set(['True', 'False', 'None'])
if compat.py3k:
    reserved.add('print')

    # accessor for the name of a function argument node; Python 3
    # argument nodes keep their name in the ``arg`` attribute
    arg_id = operator.attrgetter('arg')
else:
    # Python 2 function arguments are Name nodes, named by ``id``
    arg_id = operator.attrgetter('id')
31
32import _ast
33util.restore__ast(_ast)
34from mako import _ast_util
35
36
def parse(code, mode='exec', **exception_kwargs):
    """Parse a string of Python source code into an AST.

    :param code: the Python source text to parse.
    :param mode: compile mode handed to ``_ast_util.parse``; defaults to
     ``'exec'``.
    :param exception_kwargs: extra keyword arguments forwarded to
     :class:`.exceptions.SyntaxException` when parsing fails.
    :raises exceptions.SyntaxException: wrapping any exception raised
     while parsing; the message contains the original exception's class
     name, its text, and the first 50 characters of ``code``.
    """
    try:
        return _ast_util.parse(code, '<unknown>', mode)
    except Exception as err:
        # bind the exception once instead of repeatedly consulting
        # sys.exc_info() through compat.exception_as()
        raise exceptions.SyntaxException(
            "(%s) %s (%r)" % (
                err.__class__.__name__,
                err,
                code[0:50]
            ), **exception_kwargs)
49
50
class FindIdentifiers(_ast_util.NodeVisitor):
    """AST visitor that sorts the identifiers of a code block.

    Names bound outside of any function or lambda are added to
    ``listener.declared_identifiers``.  Names that are read while not
    reserved, not already declared, and not local to an enclosing
    function or lambda are added to ``listener.undeclared_identifiers``.
    Names bound inside a function or lambda are kept in a local set and
    do not reach the listener's declared set.
    """

    def __init__(self, listener, **exception_kwargs):
        self.in_function = False
        # toggled while the targets of an Assign are visited
        self.in_assign_targets = False
        # names local to the function(s)/lambda(s) currently being visited
        self.local_ident_stack = set()
        self.listener = listener
        # forwarded to exceptions raised by this visitor
        self.exception_kwargs = exception_kwargs

    def _add_declared(self, name):
        """Record ``name`` as bound in the scope currently being visited."""
        if not self.in_function:
            self.listener.declared_identifiers.add(name)
        else:
            self.local_ident_stack.add(name)

    def visit_ClassDef(self, node):
        # only the class name is recorded; the class body is not visited
        self._add_declared(node.name)

    def visit_Assign(self, node):

        # flip around the visiting of Assign so the expression gets
        # evaluated first, in the case of a clause like "x=x+5" (x
        # is undeclared)

        self.visit(node.value)
        in_a = self.in_assign_targets
        self.in_assign_targets = True
        for n in node.targets:
            self.visit(n)
        self.in_assign_targets = in_a

    if compat.py3k:

        # ExceptHandler is in Python 2, but this block only works in
        # Python 3 (and is required there): on Python 3 the "as" name
        # is a plain string rather than a Name node

        def visit_ExceptHandler(self, node):
            if node.name is not None:
                self._add_declared(node.name)
            if node.type is not None:
                self.visit(node.type)
            for statement in node.body:
                self.visit(statement)

    def visit_Lambda(self, node, *args):
        self._visit_function(node, True)

    def visit_FunctionDef(self, node):
        # decorators are evaluated in the enclosing scope before the
        # function name is bound, so visit them first
        for decorator in getattr(node, 'decorator_list', ()):
            self.visit(decorator)
        self._add_declared(node.name)
        self._visit_function(node, False)

    def _expand_tuples(self, args):
        """Yield the argument nodes in ``args``, flattening (possibly
        nested) Python 2 tuple parameters such as ``def f((a, (b, c)))``.
        """
        for arg in args:
            if isinstance(arg, _ast.Tuple):
                for n in self._expand_tuples(arg.elts):
                    yield n
            else:
                yield arg

    def _argument_names(self, args):
        """Return every name bound by the ``arguments`` node ``args``:
        positional-only, regular, keyword-only, ``*vararg`` and ``**kwarg``.
        """
        positional = list(getattr(args, 'posonlyargs', ())) + list(args.args)
        names = [arg_id(arg) for arg in self._expand_tuples(positional)]
        names.extend(
            arg_id(arg) for arg in getattr(args, 'kwonlyargs', ()))
        for star in (args.vararg, args.kwarg):
            if star:
                names.append(arg_stringname(star))
        return names

    def _visit_function(self, node, islambda):
        """Visit a function or lambda, keeping its names local.

        Default values are evaluated in the enclosing scope, so they are
        visited before entering the function.  While the body is visited,
        no identifiers are logged as "declared" (they go to the local set
        instead), but undeclared identifiers are still reported.  All
        parameter names are added to the local set so they are not
        counted as undeclared.
        """
        args = node.args
        for default in args.defaults:
            self.visit(default)
        # keyword-only defaults (Python 3) use None for "no default"
        for default in getattr(args, 'kw_defaults', ()):
            if default is not None:
                self.visit(default)

        inf = self.in_function
        local_ident_stack = self.local_ident_stack
        self.in_function = True
        self.local_ident_stack = local_ident_stack.union(
            self._argument_names(args))
        try:
            if islambda:
                self.visit(node.body)
            else:
                for n in node.body:
                    self.visit(n)
        finally:
            # pop the function state even if visiting the body raised
            self.in_function = inf
            self.local_ident_stack = local_ident_stack

    def visit_For(self, node):

        # flip around visit: the iterable is evaluated before the loop
        # target is bound

        self.visit(node.iter)
        self.visit(node.target)
        for statement in node.body:
            self.visit(statement)
        for statement in node.orelse:
            self.visit(statement)

    def visit_Name(self, node):
        if isinstance(node.ctx, _ast.Store):
            # this is equivalent to visit_AssName in the old
            # ``compiler`` module
            self._add_declared(node.id)
            return
        name = node.id
        if (name not in reserved
                and name not in self.listener.declared_identifiers
                and name not in self.local_ident_stack):
            self.listener.undeclared_identifiers.add(name)

    def visit_Import(self, node):
        for name in node.names:
            if name.asname is not None:
                self._add_declared(name.asname)
            else:
                # "import a.b.c" binds only the top-level name "a"
                self._add_declared(name.name.split('.')[0])

    def visit_ImportFrom(self, node):
        for name in node.names:
            if name.asname is not None:
                self._add_declared(name.asname)
            else:
                if name.name == '*':
                    raise exceptions.CompileException(
                        "'import *' is not supported, since all identifier "
                        "names must be explicitly declared.  Please use the "
                        "form 'from <modulename> import <name1>, <name2>, "
                        "...' instead.", **self.exception_kwargs)
                self._add_declared(name.name)
173
174
class FindTuple(_ast_util.NodeVisitor):
    """AST visitor that splits a tuple expression into its elements.

    For every element of a visited ``Tuple`` node, an object is built
    with ``code_factory`` and appended to ``listener.codeargs``, the
    element's regenerated source text is appended to ``listener.args``,
    and that object's declared/undeclared identifiers are merged into the
    listener's corresponding sets.
    """

    def __init__(self, listener, code_factory, **exception_kwargs):
        self.listener = listener
        self.exception_kwargs = exception_kwargs
        self.code_factory = code_factory

    def visit_Tuple(self, node):
        for element in node.elts:
            code = self.code_factory(element, **self.exception_kwargs)
            self.listener.codeargs.append(code)
            source = ExpressionGenerator(element).value()
            self.listener.args.append(source)

            # union() builds new sets, so the listener's attributes are
            # rebound rather than updated in place
            merged_declared = self.listener.declared_identifiers.union(
                code.declared_identifiers)
            self.listener.declared_identifiers = merged_declared
            merged_undeclared = self.listener.undeclared_identifiers.union(
                code.undeclared_identifiers)
            self.listener.undeclared_identifiers = merged_undeclared
193
194
class ParseFunc(_ast_util.NodeVisitor):
    """AST visitor that extracts the signature of a function definition.

    Visiting a ``FunctionDef`` stores on ``listener``:

    * ``funcname`` - the function's name;
    * ``argnames`` - positional parameter names (positional-only ones
      first), followed by the ``*vararg`` name if present;
    * ``defaults`` - the AST nodes of the positional default values;
    * ``kwargnames`` - keyword-only parameter names (always empty on
      Python 2), followed by the ``**kwarg`` name if present;
    * ``kwdefaults`` - the keyword-only default nodes (empty on Python 2);
    * ``varargs`` / ``kwargs`` - the raw ``vararg`` / ``kwarg`` fields.
    """

    def __init__(self, listener, **exception_kwargs):
        self.listener = listener
        self.exception_kwargs = exception_kwargs

    def visit_FunctionDef(self, node):
        args = node.args
        self.listener.funcname = node.name

        # positional-only parameters (Python 3.8+) come before the regular
        # ones; ``args.defaults`` lines up with the tail of both lists
        # combined, so they must be included to keep that alignment.
        positional = list(getattr(args, 'posonlyargs', ())) + list(args.args)
        argnames = [arg_id(arg) for arg in positional]
        if args.vararg:
            argnames.append(arg_stringname(args.vararg))

        if compat.py2k:
            # kw-only args don't exist in Python 2
            kwargnames = []
            kwdefaults = []
        else:
            kwargnames = [arg_id(arg) for arg in args.kwonlyargs]
            kwdefaults = args.kw_defaults
        if args.kwarg:
            kwargnames.append(arg_stringname(args.kwarg))

        self.listener.argnames = argnames
        self.listener.defaults = args.defaults  # ast
        self.listener.kwargnames = kwargnames
        self.listener.kwdefaults = kwdefaults
        self.listener.varargs = args.vararg
        self.listener.kwargs = args.kwarg
224
class ExpressionGenerator(object):
    """Regenerate Python source text from an AST node."""

    def __init__(self, astnode):
        indent = ' ' * 4
        source_generator = _ast_util.SourceGenerator(indent)
        source_generator.visit(astnode)
        self.generator = source_generator

    def value(self):
        """Return the regenerated source as a single string."""
        pieces = self.generator.result
        return ''.join(pieces)
233