symbols.py revision 354433a59dd1701985ec8463ced25a967cade3c1
1"""Module symbol-table generator"""
2
3from compiler import ast
4from compiler.consts import SC_LOCAL, SC_GLOBAL, SC_FREE, SC_CELL, SC_UNKNOWN
5from compiler.misc import mangle
6import types
7
8
9import sys
10
11MANGLE_LEN = 256
12
class Scope:
    """Symbol table for a single lexical scope.

    Every name table is a dict used as a set (all values are 1), keyed by
    the name after class-private mangling.  Subclasses distinguish module,
    function, lambda, generator-expression and class scopes.
    """
    # XXX how much information do I need about each name?
    def __init__(self, name, module, klass=None):
        self.name = name
        # scope on which add_global() records global declarations as defs
        self.module = module
        self.defs = {}          # names bound here (parameters included)
        self.uses = {}          # names referenced here
        self.globals = {}       # names declared global or forced global
        self.params = {}        # parameter names
        self.frees = {}         # free names passed up from nested scopes
        self.cells = {}         # locals that nested scopes refer to
        self.children = []      # directly nested Scope objects
        # nested is true if the scope could contain free variables,
        # i.e. if it is nested within another function.
        self.nested = None
        self.generator = None
        # Class name used for private-name mangling, with leading
        # underscores stripped; stays None outside a class or when the
        # class name consists only of underscores.
        self.klass = None
        if klass is not None:
            self.klass = klass.lstrip('_') or None

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self.name)

    def mangle(self, name):
        """Return name mangled for the enclosing class, if there is one."""
        if self.klass is None:
            return name
        return mangle(name, self.klass)

    def add_def(self, name):
        """Record that name is bound in this scope."""
        self.defs[self.mangle(name)] = 1

    def add_use(self, name):
        """Record that name is referenced in this scope."""
        self.uses[self.mangle(name)] = 1

    def add_global(self, name):
        """Declare name global here and record it as a def on the module.

        Raises SyntaxError if name is also a parameter of this scope.
        """
        name = self.mangle(name)
        # XXX should warn when a global statement follows a def/use of name
        if name in self.params:
            raise SyntaxError("%s in %s is global and parameter" %
                              (name, self.name))
        self.globals[name] = 1
        self.module.add_def(name)

    def add_param(self, name):
        """Record a parameter; parameters are also local definitions."""
        name = self.mangle(name)
        self.defs[name] = 1
        self.params[name] = 1

    def get_names(self):
        """Return a list of every name defined, used or declared global."""
        d = {}
        d.update(self.defs)
        d.update(self.uses)
        d.update(self.globals)
        return list(d)

    def add_child(self, child):
        self.children.append(child)

    def get_children(self):
        return self.children

    def DEBUG(self):
        """Dump this scope's name tables to stderr."""
        print >> sys.stderr, self.name, self.nested and "nested" or ""
        print >> sys.stderr, "\tglobals: ", self.globals
        print >> sys.stderr, "\tcells: ", self.cells
        print >> sys.stderr, "\tdefs: ", self.defs
        print >> sys.stderr, "\tuses: ", self.uses
        print >> sys.stderr, "\tfrees:", self.frees

    def check_name(self, name):
        """Return the scope of name.

        The result is SC_GLOBAL, SC_CELL, SC_LOCAL, SC_FREE or SC_UNKNOWN.
        Global declarations win, then cells, then local definitions.  In a
        nested scope a name that is free or merely used is SC_FREE and any
        other name is SC_UNKNOWN (its binding must be found further out);
        outside nested scopes every remaining name is SC_GLOBAL.
        """
        if name in self.globals:
            return SC_GLOBAL
        if name in self.cells:
            return SC_CELL
        if name in self.defs:
            return SC_LOCAL
        if not self.nested:
            return SC_GLOBAL
        if name in self.frees or name in self.uses:
            return SC_FREE
        return SC_UNKNOWN

    def get_free_vars(self):
        """Return the free variable names of a nested scope.

        These are the frees passed up from children plus every name used
        here that is neither defined nor declared global.  A scope that is
        not nested has no free variables and returns an empty tuple.
        """
        if not self.nested:
            return ()
        free = dict(self.frees)
        for name in self.uses:
            if name not in self.defs and name not in self.globals:
                free[name] = 1
        return list(free)

    def handle_children(self):
        """Resolve each child's free variables against this scope.

        Names that this scope decides are global are forced global in the
        child (and in its descendants that still consider them free).
        """
        for child in self.children:
            forced = self.add_frees(child.get_free_vars())
            for name in forced:
                child.force_global(name)

    def force_global(self, name):
        """Force name to be global in scope.

        Some child of the current node had a free reference to name.
        When the child was processed, it was labelled a free
        variable.  Now that all its enclosing scope have been
        processed, the name is known to be a global or builtin.  So
        walk back down the child chain and set the name to be global
        rather than free.

        Be careful to stop if a child does not think the name is
        free.
        """
        self.globals[name] = 1
        self.frees.pop(name, None)
        for child in self.children:
            if child.check_name(name) == SC_FREE:
                child.force_global(name)

    def add_frees(self, names):
        """Process list of free vars from nested scope.

        Returns a list of names that are either 1) declared global in the
        parent or 2) undefined in a top-level parent.  In either case,
        the nested scope should treat them as globals.
        """
        child_globals = []
        for name in names:
            sc = self.check_name(name)
            if self.nested:
                if sc == SC_UNKNOWN or sc == SC_FREE \
                   or isinstance(self, ClassScope):
                    # Unresolved here (class bodies never create cells in
                    # this scheme): pass the name further out as free.
                    self.frees[name] = 1
                elif sc == SC_GLOBAL:
                    child_globals.append(name)
                elif isinstance(self, FunctionScope) and sc == SC_LOCAL:
                    # a local of this function captured by the child
                    self.cells[name] = 1
                elif sc != SC_CELL:
                    child_globals.append(name)
            elif sc == SC_LOCAL:
                self.cells[name] = 1
            elif sc != SC_CELL:
                child_globals.append(name)
        return child_globals

    def get_cell_vars(self):
        """Return the list of locals that nested scopes refer to."""
        return list(self.cells)

class ModuleScope(Scope):
    """The outermost scope of a module, named "global".

    It serves as its own module scope, so global declarations made in any
    scope end up recorded here as definitions.
    """

    def __init__(self):
        Scope.__init__(self, "global", self)

class FunctionScope(Scope):
    """Scope for a function body (LambdaScope derives from it).

    It adds nothing to Scope; the symbol visitor and Scope.add_frees test
    for this class to decide nesting and whether captured locals become
    cells.
    """

class GenExprScope(Scope):
    """Scope for a generator expression.

    Instances are named "generator expression<N>" with N increasing per
    instance.  An implicit parameter named '[outmost-iterable]' is
    registered (presumably the slot through which the code generator
    passes the outermost iterable -- confirm against pycodegen).
    """
    __super_init = Scope.__init__

    __counter = 1

    def __init__(self, module, klass=None):
        # Read and bump the counter on the class itself: "self.__counter
        # += 1" would only create an instance attribute, leaving every
        # generator expression numbered 1.
        i = GenExprScope.__counter
        GenExprScope.__counter = i + 1
        self.__super_init("generator expression<%d>" % i, module, klass)
        self.add_param('[outmost-iterable]')

    # get_names is inherited from Scope; the former override called
    # Scope.get_names() without self and raised TypeError.

class LambdaScope(FunctionScope):
    """Scope for a lambda expression, named "lambda.N" with N increasing
    per instance."""
    __super_init = Scope.__init__

    __counter = 1

    def __init__(self, module, klass=None):
        # Bump the counter on the class, not the instance, so successive
        # lambdas receive distinct numbers.
        i = LambdaScope.__counter
        LambdaScope.__counter = i + 1
        self.__super_init("lambda.%d" % i, module, klass)

class ClassScope(Scope):
    """Scope for a class body.

    The class name doubles as the mangling class, so private names used
    in the body are mangled with it.
    """

    def __init__(self, name, module):
        Scope.__init__(self, name, module, klass=name)

class SymbolVisitor:
    """AST visitor that builds a Scope for every scope-defining node.

    After a walk, self.scopes maps each Module/Expression, Function,
    Lambda, GenExpr and Class node to its Scope, and self.module is the
    ModuleScope.  Visit methods receive the current scope as an extra
    argument; expression visitors may also receive an "assign" flag.
    """
    def __init__(self):
        self.scopes = {}
        # name of the innermost enclosing class, used for name mangling
        self.klass = None

    def _is_nested(self, parent):
        """Return true if a scope opened inside parent may have free
        variables: parent is itself nested or is function-like (function,
        lambda or generator expression)."""
        return parent.nested or isinstance(parent, (FunctionScope,
                                                    GenExprScope))

    # node that define new scopes

    def visitModule(self, node):
        scope = self.module = self.scopes[node] = ModuleScope()
        self.visit(node.node, scope)

    visitExpression = visitModule

    def visitFunction(self, node, parent):
        # Decorator expressions are evaluated in the enclosing scope.
        # getattr keeps this working for ASTs whose Function node has no
        # decorators attribute.
        decorators = getattr(node, 'decorators', None)
        if decorators is not None:
            self.visit(decorators, parent)
        parent.add_def(node.name)
        for n in node.defaults:
            self.visit(n, parent)
        scope = FunctionScope(node.name, self.module, self.klass)
        if self._is_nested(parent):
            scope.nested = 1
        self.scopes[node] = scope
        self._do_args(scope, node.argnames)
        self.visit(node.code, scope)
        self.handle_free_vars(scope, parent)

    def visitGenExpr(self, node, parent):
        scope = GenExprScope(self.module, self.klass)
        if self._is_nested(parent):
            scope.nested = 1

        self.scopes[node] = scope
        self.visit(node.code, scope)

        self.handle_free_vars(scope, parent)

    def visitGenExprInner(self, node, scope):
        for genfor in node.quals:
            self.visit(genfor, scope)

        self.visit(node.expr, scope)

    def visitGenExprFor(self, node, scope):
        self.visit(node.assign, scope, 1)
        self.visit(node.iter, scope)
        for if_ in node.ifs:
            self.visit(if_, scope)

    def visitGenExprIf(self, node, scope):
        self.visit(node.test, scope)

    def visitLambda(self, node, parent, assign=0):
        # Lambda is an expression, so it could appear in an expression
        # context where assign is passed.  The transformer should catch
        # any code that has a lambda on the left-hand side.
        assert not assign

        for n in node.defaults:
            self.visit(n, parent)
        scope = LambdaScope(self.module, self.klass)
        # A lambda inside a generator expression must be nested too, or
        # its references to the genexpr's loop variables become globals.
        if self._is_nested(parent):
            scope.nested = 1
        self.scopes[node] = scope
        self._do_args(scope, node.argnames)
        self.visit(node.code, scope)
        self.handle_free_vars(scope, parent)

    def _do_args(self, scope, args):
        """Add every parameter name to scope, flattening nested tuple
        parameters such as def f(a, (b, c))."""
        for name in args:
            if isinstance(name, tuple):
                self._do_args(scope, name)
            else:
                scope.add_param(name)

    def handle_free_vars(self, scope, parent):
        """Attach scope to parent and resolve its children's free names."""
        parent.add_child(scope)
        scope.handle_children()

    def visitClass(self, node, parent):
        # Class decorators, when the AST has them, run in the enclosing
        # scope before the class body.
        decorators = getattr(node, 'decorators', None)
        if decorators is not None:
            self.visit(decorators, parent)
        parent.add_def(node.name)
        for n in node.bases:
            self.visit(n, parent)
        scope = ClassScope(node.name, self.module)
        if self._is_nested(parent):
            scope.nested = 1
        if node.doc is not None:
            scope.add_def('__doc__')
        scope.add_def('__module__')
        self.scopes[node] = scope
        prev = self.klass
        self.klass = node.name
        self.visit(node.code, scope)
        self.klass = prev
        self.handle_free_vars(scope, parent)

    # name can be a def or a use

    # XXX a few calls and nodes expect a third "assign" arg that is
    # true if the name is being used as an assignment.  only
    # expressions contained within statements may have the assign arg.

    def visitName(self, node, scope, assign=0):
        if assign:
            scope.add_def(node.name)
        else:
            scope.add_use(node.name)

    # operations that bind new names

    def visitFor(self, node, scope):
        self.visit(node.assign, scope, 1)
        self.visit(node.list, scope)
        self.visit(node.body, scope)
        if node.else_:
            self.visit(node.else_, scope)

    def visitFrom(self, node, scope):
        # "from m import *" binds nothing that can be known statically
        for name, asname in node.names:
            if name == "*":
                continue
            scope.add_def(asname or name)

    def visitImport(self, node, scope):
        # "import a.b.c" binds only the top-level package name "a"
        for name, asname in node.names:
            i = name.find(".")
            if i > -1:
                name = name[:i]
            scope.add_def(asname or name)

    def visitGlobal(self, node, scope):
        for name in node.names:
            scope.add_global(name)

    def visitAssign(self, node, scope):
        """Propagate assignment flag down to child nodes.

        The Assign node doesn't itself contain the variables being
        assigned to.  Instead, the children in node.nodes are visited
        with the assign flag set to true.  When the names occur in
        those nodes, they are marked as defs.

        Some names that occur in an assignment target are not bound by
        the assignment, e.g. a name occurring inside a slice.  The
        visitor handles these nodes specially; they do not propagate
        the assign flag to their children.
        """
        for n in node.nodes:
            self.visit(n, scope, 1)
        self.visit(node.expr, scope)

    def visitAssName(self, node, scope, assign=1):
        scope.add_def(node.name)

    def visitAssAttr(self, node, scope, assign=0):
        self.visit(node.expr, scope, 0)

    def visitSubscript(self, node, scope, assign=0):
        self.visit(node.expr, scope, 0)
        for n in node.subs:
            self.visit(n, scope, 0)

    def visitSlice(self, node, scope, assign=0):
        self.visit(node.expr, scope, 0)
        if node.lower:
            self.visit(node.lower, scope, 0)
        if node.upper:
            self.visit(node.upper, scope, 0)

    def visitAugAssign(self, node, scope):
        # If the LHS is a name, then this counts as assignment.
        # Otherwise, it's just use.
        self.visit(node.node, scope)
        if isinstance(node.node, ast.Name):
            self.visit(node.node, scope, 1) # XXX worry about this
        self.visit(node.expr, scope)

    # prune if statements if tests are false

    # Same types as types.StringType/IntType/FloatType in Python 2.
    _const_types = (str, int, float)

    def visitIf(self, node, scope):
        """Visit an if statement, skipping branches whose test is a false
        str/int/float constant; the else branch is always visited."""
        for test, body in node.tests:
            if isinstance(test, ast.Const):
                if type(test.value) in self._const_types:
                    if not test.value:
                        continue
            self.visit(test, scope)
            self.visit(body, scope)
        if node.else_:
            self.visit(node.else_, scope)

    # a yield statement signals a generator

    def visitYield(self, node, scope):
        scope.generator = 1
        self.visit(node.value, scope)

def sort(l):
    """Return a new sorted list of the items in l, leaving l unchanged.

    l may be any finite iterable (list, tuple, dict keys, ...), not just
    a list.
    """
    return sorted(l)

def list_eq(l1, l2):
    """Return true if l1 and l2 hold the same items, ignoring order.

    Duplicates count: [1, 1, 2] and [1, 2] are not equal.  Both arguments
    may be any finite iterables of mutually comparable items.
    """
    return sorted(l1) == sorted(l2)

418if __name__ == "__main__":
419    import sys
420    from compiler import parseFile, walk
421    import symtable
422
423    def get_names(syms):
424        return [s for s in [s.get_name() for s in syms.get_symbols()]
425                if not (s.startswith('_[') or s.startswith('.'))]
426
427    for file in sys.argv[1:]:
428        print file
429        f = open(file)
430        buf = f.read()
431        f.close()
432        syms = symtable.symtable(buf, file, "exec")
433        mod_names = get_names(syms)
434        tree = parseFile(file)
435        s = SymbolVisitor()
436        walk(tree, s)
437
438        # compare module-level symbols
439        names2 = s.scopes[tree].get_names()
440
441        if not list_eq(mod_names, names2):
442            print
443            print "oops", file
444            print sort(mod_names)
445            print sort(names2)
446            sys.exit(-1)
447
448        d = {}
449        d.update(s.scopes)
450        del d[tree]
451        scopes = d.values()
452        del d
453
454        for s in syms.get_symbols():
455            if s.is_namespace():
456                l = [sc for sc in scopes
457                     if sc.name == s.get_name()]
458                if len(l) > 1:
459                    print "skipping", s.get_name()
460                else:
461                    if not list_eq(get_names(s.get_namespace()),
462                                   l[0].get_names()):
463                        print s.get_name()
464                        print sort(get_names(s.get_namespace()))
465                        print sort(l[0].get_names())
466                        sys.exit(-1)
467