1"""Module symbol-table generator"""
2
3from compiler import ast
4from compiler.consts import SC_LOCAL, SC_GLOBAL_IMPLICIT, SC_GLOBAL_EXPLICIT, \
5    SC_FREE, SC_CELL, SC_UNKNOWN
6from compiler.misc import mangle
7import types
8
9
10import sys
11
# Not referenced anywhere in the visible part of this module.
# NOTE(review): presumably the maximum length of a mangled private name,
# mirroring the limit applied by compiler.misc.mangle -- confirm.
MANGLE_LEN = 256
13
class Scope:
    """Names defined, used and declared global within one lexical scope.

    defs, uses, globals, params, frees and cells are dicts used as sets:
    the keys are (mangled) names and every value is 1.
    """

    # XXX how much information do I need about each name?
    def __init__(self, name, module, klass=None):
        """Create an empty scope.

        name   -- label used by repr(), DEBUG() and error messages
        module -- scope that also records names declared global here
        klass  -- name of the enclosing class, or None; it is used to
                  mangle private names
        """
        self.name = name
        self.module = module
        self.defs = {}
        self.uses = {}
        self.globals = {}
        self.params = {}
        self.frees = {}
        self.cells = {}
        self.children = []
        # nested is true if the scope could contain free variables,
        # i.e. if it is nested within another function.
        self.nested = None
        self.generator = None
        # Leading underscores of the class name are dropped.  A name made
        # up only of underscores (or an empty one) leaves klass as None,
        # which disables mangling altogether.
        self.klass = None
        if klass is not None:
            self.klass = klass.lstrip('_') or None

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self.name)

    def mangle(self, name):
        """Return name mangled for the enclosing class, if there is one."""
        if self.klass is None:
            return name
        return mangle(name, self.klass)

    def add_def(self, name):
        """Record that name is bound in this scope."""
        self.defs[self.mangle(name)] = 1

    def add_use(self, name):
        """Record that name is referenced in this scope."""
        self.uses[self.mangle(name)] = 1

    def add_global(self, name):
        """Record a global declaration for name.

        The name is also added to the module scope's defs.  Raises
        SyntaxError if the name is a parameter of this scope.
        """
        name = self.mangle(name)
        # XXX warn about global following def/use
        if name in self.params:
            raise SyntaxError("%s in %s is global and parameter" %
                              (name, self.name))
        self.globals[name] = 1
        self.module.add_def(name)

    def add_param(self, name):
        """Record name as a parameter; parameters also count as defs."""
        name = self.mangle(name)
        self.defs[name] = 1
        self.params[name] = 1

    def get_names(self):
        """Return every name that is defined, used or global here."""
        d = {}
        d.update(self.defs)
        d.update(self.uses)
        d.update(self.globals)
        return d.keys()

    def add_child(self, child):
        self.children.append(child)

    def get_children(self):
        return self.children

    def DEBUG(self):
        """Dump the name tables of this scope to stderr."""
        print >> sys.stderr, self.name, self.nested and "nested" or ""
        print >> sys.stderr, "\tglobals: ", self.globals
        print >> sys.stderr, "\tcells: ", self.cells
        print >> sys.stderr, "\tdefs: ", self.defs
        print >> sys.stderr, "\tuses: ", self.uses
        print >> sys.stderr, "\tfrees:", self.frees

    def check_name(self, name):
        """Return scope of name.

        The scope of a name could be LOCAL, GLOBAL, FREE, or CELL.
        The tests are ordered: an explicit global wins over a cell,
        which wins over a plain local definition.  A name that matches
        none of the tables is UNKNOWN in a nested scope and an implicit
        global otherwise.
        """
        if name in self.globals:
            return SC_GLOBAL_EXPLICIT
        if name in self.cells:
            return SC_CELL
        if name in self.defs:
            return SC_LOCAL
        if self.nested and (name in self.frees or name in self.uses):
            return SC_FREE
        if self.nested:
            return SC_UNKNOWN
        else:
            return SC_GLOBAL_IMPLICIT

    def get_free_vars(self):
        """Return the names that are free in this scope.

        A scope that is not nested has no free variables and returns an
        empty tuple.  Otherwise the result holds the recorded frees plus
        every used name that is neither defined nor global here.
        """
        if not self.nested:
            return ()
        free = {}
        free.update(self.frees)
        for name in self.uses:
            if name not in self.defs and name not in self.globals:
                free[name] = 1
        return free.keys()

    def handle_children(self):
        """Resolve the free variables reported by each child scope.

        Names that add_frees() hands back are forced global in the child.
        """
        for child in self.children:
            forced = self.add_frees(child.get_free_vars())
            for name in forced:
                child.force_global(name)

    def force_global(self, name):
        """Force name to be global in scope.

        Some child of the current node had a free reference to name.
        When the child was processed, it was labelled a free
        variable.  Now that all its enclosing scopes have been
        processed, the name is known to be a global or builtin.  So
        walk back down the child chain and set the name to be global
        rather than free.

        Be careful to stop if a child does not think the name is
        free.
        """
        self.globals[name] = 1
        if name in self.frees:
            del self.frees[name]
        for child in self.children:
            if child.check_name(name) == SC_FREE:
                child.force_global(name)

    def add_frees(self, names):
        """Process list of free vars from nested scope.

        Returns a list of names that are either 1) declared global in the
        parent or 2) undefined in a top-level parent.  In either case,
        the nested scope should treat them as globals.
        """
        child_globals = []
        for name in names:
            sc = self.check_name(name)
            if self.nested:
                # A nested scope passes unresolved names further up by
                # recording them as its own free variables; a class
                # scope does so for every name.
                if sc == SC_UNKNOWN or sc == SC_FREE \
                   or isinstance(self, ClassScope):
                    self.frees[name] = 1
                elif sc == SC_GLOBAL_IMPLICIT:
                    child_globals.append(name)
                elif isinstance(self, FunctionScope) and sc == SC_LOCAL:
                    # a local captured by a child becomes a cell
                    self.cells[name] = 1
                elif sc != SC_CELL:
                    child_globals.append(name)
            else:
                if sc == SC_LOCAL:
                    self.cells[name] = 1
                elif sc != SC_CELL:
                    child_globals.append(name)
        return child_globals

    def get_cell_vars(self):
        return self.cells.keys()
171
class ModuleScope(Scope):
    """Top-level scope; it acts as its own module scope."""

    def __init__(self):
        # The scope is always labelled "global", has no enclosing class
        # and passes itself as the module argument.
        Scope.__init__(self, "global", self)
177
class FunctionScope(Scope):
    """Scope of a function body.

    Adds nothing to Scope; the class exists so that other code can
    recognise function scopes with isinstance().
    """
180
class GenExprScope(Scope):
    """Scope of a generator expression.

    Every instance is named "generator expression<N>", where N comes
    from a counter shared by the whole class, and starts out with the
    single parameter '.0'.
    NOTE(review): '.0' is presumably the implicit iterator argument of
    the generated code object -- confirm against the code generator.
    """
    __super_init = Scope.__init__

    # Number given to the next instance; shared by all instances.
    __counter = 1

    def __init__(self, module, klass=None):
        # Advance the counter on the class itself.  "self.__counter += 1"
        # would merely create an instance attribute and leave the shared
        # counter at 1, so every scope would get the same number.
        i = GenExprScope.__counter
        GenExprScope.__counter = i + 1
        self.__super_init("generator expression<%d>"%i, module, klass)
        self.add_param('.0')

    def get_names(self):
        """Return the names of this scope, exactly as Scope.get_names()."""
        return Scope.get_names(self)
195
class LambdaScope(FunctionScope):
    """Scope of a lambda expression.

    Every instance is named "lambda.N", where N comes from a counter
    shared by the whole class.
    """
    __super_init = Scope.__init__

    # Number given to the next instance; shared by all instances.
    __counter = 1

    def __init__(self, module, klass=None):
        # Advance the counter on the class itself.  "self.__counter += 1"
        # would merely create an instance attribute and leave the shared
        # counter at 1, so every lambda would be called "lambda.1".
        i = LambdaScope.__counter
        LambdaScope.__counter = i + 1
        self.__super_init("lambda.%d" % i, module, klass)
205
class ClassScope(Scope):
    """Scope of a class body."""

    def __init__(self, name, module):
        # The class name serves both as the scope label and as the
        # klass argument from which private names are mangled.
        Scope.__init__(self, name, module, name)
211
class SymbolVisitor:
    """AST visitor that creates a scope object for each scope-defining node.

    self.scopes maps every node visited by visitModule/visitExpression,
    visitFunction, visitGenExpr, visitLambda and visitClass to the scope
    created for it.  self.klass holds the name of the class whose body
    is currently being visited, or None.

    NOTE(review): self.visit is not defined in this class; presumably it
    is attached by compiler.visitor.walk() -- confirm.
    """

    def __init__(self):
        self.scopes = {}
        self.klass = None

    # node that define new scopes

    def visitModule(self, node):
        """Create the module scope and visit the module body in it."""
        scope = self.module = self.scopes[node] = ModuleScope()
        self.visit(node.node, scope)

    visitExpression = visitModule

    def visitFunction(self, node, parent):
        """Bind the function name in parent and build the function scope.

        Decorators and default values are visited in the parent scope;
        the arguments and the body belong to the new scope.
        """
        if node.decorators:
            self.visit(node.decorators, parent)
        parent.add_def(node.name)
        for n in node.defaults:
            self.visit(n, parent)
        scope = FunctionScope(node.name, self.module, self.klass)
        if parent.nested or isinstance(parent, FunctionScope):
            scope.nested = 1
        self.scopes[node] = scope
        self._do_args(scope, node.argnames)
        self.visit(node.code, scope)
        self.handle_free_vars(scope, parent)

    def visitGenExpr(self, node, parent):
        """Build the scope of a generator expression.

        Unlike functions, a generator expression is also considered
        nested when its parent is another generator expression.
        """
        scope = GenExprScope(self.module, self.klass)
        if parent.nested or isinstance(parent, FunctionScope) \
                or isinstance(parent, GenExprScope):
            scope.nested = 1

        self.scopes[node] = scope
        self.visit(node.code, scope)

        self.handle_free_vars(scope, parent)

    def visitGenExprInner(self, node, scope):
        # the qualifiers are visited before the result expression
        for genfor in node.quals:
            self.visit(genfor, scope)

        self.visit(node.expr, scope)

    def visitGenExprFor(self, node, scope):
        # the loop target is visited with the assign flag set
        self.visit(node.assign, scope, 1)
        self.visit(node.iter, scope)
        for if_ in node.ifs:
            self.visit(if_, scope)

    def visitGenExprIf(self, node, scope):
        self.visit(node.test, scope)

    def visitLambda(self, node, parent, assign=0):
        """Build the scope of a lambda; defaults belong to the parent."""
        # Lambda is an expression, so it could appear in an expression
        # context where assign is passed.  The transformer should catch
        # any code that has a lambda on the left-hand side.
        assert not assign

        for n in node.defaults:
            self.visit(n, parent)
        scope = LambdaScope(self.module, self.klass)
        if parent.nested or isinstance(parent, FunctionScope):
            scope.nested = 1
        self.scopes[node] = scope
        self._do_args(scope, node.argnames)
        self.visit(node.code, scope)
        self.handle_free_vars(scope, parent)

    def _do_args(self, scope, args):
        """Add every argument name to scope, flattening nested tuples."""
        for name in args:
            if isinstance(name, tuple):
                self._do_args(scope, name)
            else:
                scope.add_param(name)

    def handle_free_vars(self, scope, parent):
        """Attach scope to parent, then resolve scope's own children."""
        parent.add_child(scope)
        scope.handle_children()

    def visitClass(self, node, parent):
        """Bind the class name in parent and build the class scope.

        Base classes are visited in the parent scope.  '__module__' is
        always defined in the new scope, '__doc__' only when the class
        has a docstring.  self.klass is set to the class name while the
        body is visited and restored afterwards.
        """
        parent.add_def(node.name)
        for n in node.bases:
            self.visit(n, parent)
        scope = ClassScope(node.name, self.module)
        if parent.nested or isinstance(parent, FunctionScope):
            scope.nested = 1
        if node.doc is not None:
            scope.add_def('__doc__')
        scope.add_def('__module__')
        self.scopes[node] = scope
        prev = self.klass
        self.klass = node.name
        self.visit(node.code, scope)
        self.klass = prev
        self.handle_free_vars(scope, parent)

    # name can be a def or a use

    # XXX a few calls and nodes expect a third "assign" arg that is
    # true if the name is being used as an assignment.  only
    # expressions contained within statements may have the assign arg.

    def visitName(self, node, scope, assign=0):
        if assign:
            scope.add_def(node.name)
        else:
            scope.add_use(node.name)

    # operations that bind new names

    def visitFor(self, node, scope):
        # the loop target is visited with the assign flag set
        self.visit(node.assign, scope, 1)
        self.visit(node.list, scope)
        self.visit(node.body, scope)
        if node.else_:
            self.visit(node.else_, scope)

    def visitFrom(self, node, scope):
        for name, asname in node.names:
            # "from m import *" records no definition
            if name == "*":
                continue
            scope.add_def(asname or name)

    def visitImport(self, node, scope):
        for name, asname in node.names:
            # without an "as" clause only the part before the first dot
            # of a dotted module name is recorded
            i = name.find(".")
            if i > -1:
                name = name[:i]
            scope.add_def(asname or name)

    def visitGlobal(self, node, scope):
        for name in node.names:
            scope.add_global(name)

    def visitAssign(self, node, scope):
        """Propagate assignment flag down to child nodes.

        The Assign node doesn't itself contain the variables being
        assigned to.  Instead, the children in node.nodes are visited
        with the assign flag set to true.  When the names occur in
        those nodes, they are marked as defs.

        Some names that occur in an assignment target are not bound by
        the assignment, e.g. a name occurring inside a slice.  The
        visitor handles these nodes specially; they do not propagate
        the assign flag to their children.
        """
        for n in node.nodes:
            self.visit(n, scope, 1)
        self.visit(node.expr, scope)

    def visitAssName(self, node, scope, assign=1):
        scope.add_def(node.name)

    def visitAssAttr(self, node, scope, assign=0):
        # the object whose attribute is assigned is only used
        self.visit(node.expr, scope, 0)

    def visitSubscript(self, node, scope, assign=0):
        # neither the container nor the subscripts are bound
        self.visit(node.expr, scope, 0)
        for n in node.subs:
            self.visit(n, scope, 0)

    def visitSlice(self, node, scope, assign=0):
        # neither the container nor the bounds are bound
        self.visit(node.expr, scope, 0)
        if node.lower:
            self.visit(node.lower, scope, 0)
        if node.upper:
            self.visit(node.upper, scope, 0)

    def visitAugAssign(self, node, scope):
        # If the LHS is a name, then this counts as assignment.
        # Otherwise, it's just use.
        self.visit(node.node, scope)
        if isinstance(node.node, ast.Name):
            self.visit(node.node, scope, 1) # XXX worry about this
        self.visit(node.expr, scope)

    # prune if statements if tests are false

    # same objects as types.StringType, types.IntType, types.FloatType
    _const_types = (str, int, float)

    def visitIf(self, node, scope):
        """Visit the tests and bodies, skipping constant-false branches.

        A branch whose test is a false Const of exactly one of
        _const_types is skipped entirely: neither the test nor the body
        is visited.
        """
        for test, body in node.tests:
            # The exact type comparison is deliberate: replacing it with
            # isinstance() would change which constants are pruned.
            if isinstance(test, ast.Const) \
               and type(test.value) in self._const_types \
               and not test.value:
                continue
            self.visit(test, scope)
            self.visit(body, scope)
        if node.else_:
            self.visit(node.else_, scope)

    # a yield statement signals a generator

    def visitYield(self, node, scope):
        scope.generator = 1
        self.visit(node.value, scope)
410
def list_eq(l1, l2):
    """Return true if l1 and l2 hold the same items, ignoring order."""
    first = list(l1)
    second = list(l2)
    first.sort()
    second.sort()
    return first == second
413
414if __name__ == "__main__":
415    import sys
416    from compiler import parseFile, walk
417    import symtable
418
419    def get_names(syms):
420        return [s for s in [s.get_name() for s in syms.get_symbols()]
421                if not (s.startswith('_[') or s.startswith('.'))]
422
423    for file in sys.argv[1:]:
424        print file
425        f = open(file)
426        buf = f.read()
427        f.close()
428        syms = symtable.symtable(buf, file, "exec")
429        mod_names = get_names(syms)
430        tree = parseFile(file)
431        s = SymbolVisitor()
432        walk(tree, s)
433
434        # compare module-level symbols
435        names2 = s.scopes[tree].get_names()
436
437        if not list_eq(mod_names, names2):
438            print
439            print "oops", file
440            print sorted(mod_names)
441            print sorted(names2)
442            sys.exit(-1)
443
444        d = {}
445        d.update(s.scopes)
446        del d[tree]
447        scopes = d.values()
448        del d
449
450        for s in syms.get_symbols():
451            if s.is_namespace():
452                l = [sc for sc in scopes
453                     if sc.name == s.get_name()]
454                if len(l) > 1:
455                    print "skipping", s.get_name()
456                else:
457                    if not list_eq(get_names(s.get_namespace()),
458                                   l[0].get_names()):
459                        print s.get_name()
460                        print sorted(get_names(s.get_namespace()))
461                        print sorted(l[0].get_names())
462                        sys.exit(-1)
463