symbols.py revision d4be10dc2c8454779e1f84cf9d1ab04daf310719
1"""Module symbol-table generator"""
2
3from compiler import ast
4from compiler.consts import SC_LOCAL, SC_GLOBAL, SC_FREE, SC_CELL, SC_UNKNOWN
5from compiler.misc import mangle
6import types
7
8
9import sys
10
# NOTE(review): not referenced anywhere in the visible part of this module.
# Presumably an upper bound on the length of a mangled private name (see
# compiler.misc.mangle) -- confirm before relying on it.
MANGLE_LEN = 256
12
class Scope:
    """Name-binding information collected for one lexical scope.

    Every table is a dict used as a set (name -> 1):

    defs    -- names bound in this scope (parameters included)
    uses    -- names referenced in this scope
    globals -- names declared global, or forced global by a parent
    params  -- formal parameter names
    frees   -- free variables passed up from nested scopes
    cells   -- locals that a nested scope refers to

    `children` lists the scopes nested directly inside this one.
    """
    # XXX how much information do I need about each name?
    def __init__(self, name, module, klass=None):
        """Create an empty scope.

        name   -- label for the scope (used by __repr__ and messages)
        module -- the module-level scope; add_global() records names there
        klass  -- name of the enclosing class, if any, used for mangling
                  private names
        """
        self.name = name
        self.module = module
        self.defs = {}
        self.uses = {}
        self.globals = {}
        self.params = {}
        self.frees = {}
        self.cells = {}
        self.children = []
        # nested is true if the scope could contain free variables,
        # i.e. if it is nested within another function.
        self.nested = None
        self.generator = None
        self.klass = None
        if klass is not None:
            # Private names are mangled against the class name with its
            # leading underscores removed.  A name made up only of
            # underscores leaves self.klass as None (mangling disabled).
            self.klass = klass.lstrip('_') or None

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self.name)

    def mangle(self, name):
        """Return name mangled for the enclosing class, if there is one."""
        if self.klass is None:
            return name
        return mangle(name, self.klass)

    def add_def(self, name):
        """Record that name is bound in this scope."""
        self.defs[self.mangle(name)] = 1

    def add_use(self, name):
        """Record that name is referenced in this scope."""
        self.uses[self.mangle(name)] = 1

    def add_global(self, name):
        """Record a `global` declaration for name.

        The name is also recorded as a definition in the module scope.
        Raises SyntaxError if name is a parameter of this scope.
        """
        name = self.mangle(name)
        # XXX warn about global following def/use (name already present
        # in self.uses or self.defs)
        if name in self.params:
            raise SyntaxError("%s in %s is global and parameter" %
                              (name, self.name))
        self.globals[name] = 1
        self.module.add_def(name)

    def add_param(self, name):
        """Record name as a formal parameter (which is also a definition)."""
        name = self.mangle(name)
        self.defs[name] = 1
        self.params[name] = 1

    def get_names(self):
        """Return every name that is defined, used or global in this scope."""
        d = {}
        d.update(self.defs)
        d.update(self.uses)
        d.update(self.globals)
        return d.keys()

    def add_child(self, child):
        """Register child as a scope nested directly inside this one."""
        self.children.append(child)

    def get_children(self):
        """Return the list of directly nested scopes."""
        return self.children

    def DEBUG(self):
        """Dump the name tables of this scope to stderr."""
        print >> sys.stderr, self.name, self.nested and "nested" or ""
        print >> sys.stderr, "\tglobals: ", self.globals
        print >> sys.stderr, "\tcells: ", self.cells
        print >> sys.stderr, "\tdefs: ", self.defs
        print >> sys.stderr, "\tuses: ", self.uses
        print >> sys.stderr, "\tfrees:", self.frees

    def check_name(self, name):
        """Return scope of name.

        The scope of a name could be LOCAL, GLOBAL, FREE, or CELL.
        In a nested scope a name that appears in none of the tables is
        reported as UNKNOWN; in a non-nested scope it is GLOBAL.
        """
        # The order of the tests matters: an explicit global wins over a
        # cell, and a cell wins over a plain local.
        if name in self.globals:
            return SC_GLOBAL
        if name in self.cells:
            return SC_CELL
        if name in self.defs:
            return SC_LOCAL
        if self.nested:
            if name in self.frees or name in self.uses:
                return SC_FREE
            return SC_UNKNOWN
        return SC_GLOBAL

    def get_free_vars(self):
        """Return the names this scope needs from an enclosing scope.

        Returns an empty tuple for a scope that is not nested.  Otherwise
        the result holds the recorded free variables plus every used name
        that is neither defined nor global here.
        """
        if not self.nested:
            return ()
        free = {}
        free.update(self.frees)
        for name in self.uses.keys():
            if name not in self.defs and name not in self.globals:
                free[name] = 1
        return free.keys()

    def handle_children(self):
        """Resolve the free variables of each directly nested scope.

        Names that turn out to be globals from this scope's point of view
        are forced to global in the child that reported them.
        """
        for child in self.children:
            frees = child.get_free_vars()
            child_globals = self.add_frees(frees)
            for name in child_globals:
                child.force_global(name)

    def force_global(self, name):
        """Force name to be global in scope.

        Some child of the current node had a free reference to name.
        When the child was processed, it was labelled a free
        variable.  Now that all its enclosing scopes have been
        processed, the name is known to be a global or builtin.  So
        walk back down the child chain and set the name to be global
        rather than free.

        Be careful to stop if a child does not think the name is
        free.
        """
        self.globals[name] = 1
        if name in self.frees:
            del self.frees[name]
        for child in self.children:
            if child.check_name(name) == SC_FREE:
                child.force_global(name)

    def add_frees(self, names):
        """Process list of free vars from nested scope.

        Returns a list of names that are either 1) declared global in the
        parent or 2) undefined in a top-level parent.  In either case,
        the nested scope should treat them as globals.
        """
        child_globals = []
        for name in names:
            sc = self.check_name(name)
            if self.nested:
                if sc == SC_UNKNOWN or sc == SC_FREE \
                   or isinstance(self, ClassScope):
                    # Not resolvable here (class scopes never supply the
                    # binding): pass the name on to our own parent.
                    self.frees[name] = 1
                elif sc == SC_GLOBAL:
                    child_globals.append(name)
                elif isinstance(self, FunctionScope) and sc == SC_LOCAL:
                    # A local that a nested scope refers to becomes a cell.
                    self.cells[name] = 1
                elif sc != SC_CELL:
                    child_globals.append(name)
            else:
                if sc == SC_LOCAL:
                    self.cells[name] = 1
                elif sc != SC_CELL:
                    child_globals.append(name)
        return child_globals

    def get_cell_vars(self):
        """Return the names of the cell variables of this scope."""
        return self.cells.keys()
172
class ModuleScope(Scope):
    """The top-level scope of a module."""
    __super_init = Scope.__init__

    def __init__(self):
        # The module scope is always named "global" and acts as its own
        # `module` reference.
        self.__super_init(name="global", module=self)
178
class FunctionScope(Scope):
    """Scope of a function body; adds nothing to Scope but its type."""
181
class LambdaScope(FunctionScope):
    """Scope of a lambda expression, named "lambda.<N>" with N unique."""
    __super_init = Scope.__init__

    # Sequence number shared by all lambda scopes; the next one to hand out.
    __counter = 1

    def __init__(self, module, klass=None):
        """Create a lambda scope.

        module -- the module-level scope
        klass  -- name of the enclosing class, if any (for name mangling)
        """
        i = LambdaScope.__counter
        # Increment the class attribute.  `self.__counter += 1` would only
        # create an instance attribute and leave the shared counter at 1,
        # giving every lambda scope the same name.
        LambdaScope.__counter += 1
        self.__super_init("lambda.%d" % i, module, klass)
191
class ClassScope(Scope):
    """Scope of a class body."""
    __super_init = Scope.__init__

    def __init__(self, name, module):
        # The class's own name doubles as the mangling name, so private
        # names used in the class body are mangled against it.
        self.__super_init(name, module, klass=name)
197
class SymbolVisitor:
    """AST visitor that builds a Scope object for every scope-defining node.

    After a walk, `scopes` maps each Module, Function, Lambda and Class
    node to its scope object.  The visitor relies on a `visit` attribute
    being supplied from outside the class (it is not defined here) to
    dispatch on child nodes.
    """

    def __init__(self):
        self.scopes = {}
        # Name of the class currently being visited, or None.
        self.klass = None

    # node that define new scopes

    def visitModule(self, node):
        """Create the module scope and visit the module body."""
        scope = self.module = self.scopes[node] = ModuleScope()
        self.visit(node.node, scope)

    def visitFunction(self, node, parent):
        """Bind the function name in parent and build the function's scope."""
        parent.add_def(node.name)
        # Default values are evaluated in the enclosing scope.
        for n in node.defaults:
            self.visit(n, parent)
        scope = FunctionScope(node.name, self.module, self.klass)
        if parent.nested or isinstance(parent, FunctionScope):
            scope.nested = 1
        self.scopes[node] = scope
        self._do_args(scope, node.argnames)
        self.visit(node.code, scope)
        self.handle_free_vars(scope, parent)

    def visitLambda(self, node, parent):
        """Build the scope of a lambda; no name is bound in parent."""
        # Default values are evaluated in the enclosing scope.
        for n in node.defaults:
            self.visit(n, parent)
        scope = LambdaScope(self.module, self.klass)
        if parent.nested or isinstance(parent, FunctionScope):
            scope.nested = 1
        self.scopes[node] = scope
        self._do_args(scope, node.argnames)
        self.visit(node.code, scope)
        self.handle_free_vars(scope, parent)

    def _do_args(self, scope, args):
        """Record every argument name, descending into nested tuples."""
        for name in args:
            if isinstance(name, tuple):
                self._do_args(scope, name)
            else:
                scope.add_param(name)

    def handle_free_vars(self, scope, parent):
        """Attach scope to parent and resolve the free vars of its children."""
        parent.add_child(scope)
        scope.handle_children()

    def visitClass(self, node, parent):
        """Bind the class name in parent and build the class body's scope."""
        parent.add_def(node.name)
        # Base class expressions are evaluated in the enclosing scope.
        for n in node.bases:
            self.visit(n, parent)
        scope = ClassScope(node.name, self.module)
        if parent.nested or isinstance(parent, FunctionScope):
            scope.nested = 1
        self.scopes[node] = scope
        # Scopes created while visiting the body are given this class
        # name for mangling; restore the previous one afterwards.
        prev = self.klass
        self.klass = node.name
        self.visit(node.code, scope)
        self.klass = prev
        self.handle_free_vars(scope, parent)

    # name can be a def or a use

    # XXX a few calls and nodes expect a third "assign" arg that is
    # true if the name is being used as an assignment.  only
    # expressions contained within statements may have the assign arg.

    def visitName(self, node, scope, assign=0):
        """Record the name as a def when assign is true, else as a use."""
        if assign:
            scope.add_def(node.name)
        else:
            scope.add_use(node.name)

    # operations that bind new names

    def visitFor(self, node, scope):
        """Visit a for loop; the loop target is visited as an assignment."""
        self.visit(node.assign, scope, 1)
        self.visit(node.list, scope)
        self.visit(node.body, scope)
        if node.else_:
            self.visit(node.else_, scope)

    def visitFrom(self, node, scope):
        """Record the names bound by `from ... import`; `*` binds nothing."""
        for name, asname in node.names:
            if name == "*":
                continue
            scope.add_def(asname or name)

    def visitImport(self, node, scope):
        """Record the names bound by `import`.

        For a dotted module name without `as`, only the first component
        is bound.
        """
        for name, asname in node.names:
            i = name.find(".")
            if i > -1:
                name = name[:i]
            scope.add_def(asname or name)

    def visitGlobal(self, node, scope):
        """Record every name listed in a `global` statement."""
        for name in node.names:
            scope.add_global(name)

    def visitAssign(self, node, scope):
        """Propagate assignment flag down to child nodes.

        The Assign node does not itself contain the variables being
        assigned to.  Instead, the children in node.nodes are visited
        with the assign flag set to true.  When the names occur in
        those nodes, they are marked as defs.

        Some names that occur in an assignment target are not bound by
        the assignment, e.g. a name occurring inside a slice.  The
        visitor handles these nodes specially; they do not propagate
        the assign flag to their children.
        """
        for n in node.nodes:
            self.visit(n, scope, 1)
        self.visit(node.expr, scope)

    def visitAssName(self, node, scope, assign=1):
        """An AssName node always binds its name."""
        scope.add_def(node.name)

    def visitAssAttr(self, node, scope, assign=0):
        """Attribute assignment only uses the object expression."""
        self.visit(node.expr, scope, 0)

    def visitSubscript(self, node, scope, assign=0):
        """Subscripting, as a target or not, only uses its operands."""
        self.visit(node.expr, scope, 0)
        for n in node.subs:
            self.visit(n, scope, 0)

    def visitSlice(self, node, scope, assign=0):
        """Slicing, as a target or not, only uses its operands."""
        # Do not pass the assign flag on: `a[i:j] = x` stores into the
        # object named `a`, it does not rebind `a`.
        self.visit(node.expr, scope, 0)
        if node.lower:
            self.visit(node.lower, scope, 0)
        if node.upper:
            self.visit(node.upper, scope, 0)

    def visitAugAssign(self, node, scope):
        """Visit an augmented assignment."""
        # If the LHS is a name, then this counts as assignment.
        # Otherwise, it's just use.
        self.visit(node.node, scope)
        if isinstance(node.node, ast.Name):
            self.visit(node.node, scope, 1) # XXX worry about this
        self.visit(node.expr, scope)

    # prune if statements if tests are false

    # Constant types whose truth value is inspected by visitIf.
    _const_types = (str, int, float)

    def visitIf(self, node, scope):
        """Visit an if statement, skipping branches with a false constant test."""
        for test, body in node.tests:
            if (isinstance(test, ast.Const)
                    and type(test.value) in self._const_types
                    and not test.value):
                # Neither the test nor the body of this branch is visited.
                continue
            self.visit(test, scope)
            self.visit(body, scope)
        if node.else_:
            self.visit(node.else_, scope)

    # a yield statement signals a generator

    def visitYield(self, node, scope):
        """Flag the enclosing scope as a generator and visit the value."""
        # The flag belongs to the scope containing the yield, not to the
        # visitor.
        scope.generator = 1
        self.visit(node.value, scope)
358
def sort(l):
    """Return a new sorted list holding the items of l.

    l may be any iterable and is left unmodified.
    """
    # list() copies like l[:] does for a list, but also accepts tuples,
    # dict views and other iterables that have no sort() method.
    result = list(l)
    result.sort()
    return result
363
def list_eq(l1, l2):
    """Return true if l1 and l2 hold the same items, ignoring order."""
    first = sort(l1)
    second = sort(l2)
    return first == second
366
367if __name__ == "__main__":
368    import sys
369    from compiler import parseFile, walk
370    import symtable
371
372    def get_names(syms):
373        return [s for s in [s.get_name() for s in syms.get_symbols()]
374                if not (s.startswith('_[') or s.startswith('.'))]
375
376    for file in sys.argv[1:]:
377        print file
378        f = open(file)
379        buf = f.read()
380        f.close()
381        syms = symtable.symtable(buf, file, "exec")
382        mod_names = get_names(syms)
383        tree = parseFile(file)
384        s = SymbolVisitor()
385        walk(tree, s)
386
387        # compare module-level symbols
388        names2 = s.scopes[tree].get_names()
389
390        if not list_eq(mod_names, names2):
391            print
392            print "oops", file
393            print sort(mod_names)
394            print sort(names2)
395            sys.exit(-1)
396
397        d = {}
398        d.update(s.scopes)
399        del d[tree]
400        scopes = d.values()
401        del d
402
403        for s in syms.get_symbols():
404            if s.is_namespace():
405                l = [sc for sc in scopes
406                     if sc.name == s.get_name()]
407                if len(l) > 1:
408                    print "skipping", s.get_name()
409                else:
410                    if not list_eq(get_names(s.get_namespace()),
411                                   l[0].get_names()):
412                        print s.get_name()
413                        print sort(get_names(s.get_namespace()))
414                        print sort(l[0].get_names())
415                        sys.exit(-1)
416