pycodegen.py revision 364f9b9e2f798e4d28ed21122faffb030a6ccac5
1import imp
2import os
3import marshal
4import stat
5import string
6import struct
7import sys
8import types
9from cStringIO import StringIO
10
11from compiler import ast, parse, walk
12from compiler import pyassem, misc, future, symbols
13from compiler.consts import SC_LOCAL, SC_GLOBAL, SC_FREE, SC_CELL
14from compiler.pyassem import CO_VARARGS, CO_VARKEYWORDS, CO_NEWLOCALS,\
15     CO_NESTED, TupleArg
16
# Major version of the running interpreter.  sys.version_info was only
# introduced in Python 2.0, so when it is missing we must be on 1.x.
VERSION = getattr(sys, 'version_info', (1,))[0]

# Maps (call has *args, call has **args) to the opcode that performs the
# call; indexed with two booleans by CodeGenerator.visitCallFunc.
callfunc_opcode_info = {
    (0, 0): "CALL_FUNCTION",
    (1, 0): "CALL_FUNCTION_VAR",
    (0, 1): "CALL_FUNCTION_KW",
    (1, 1): "CALL_FUNCTION_VAR_KW",
}
30
def compile(filename, display=0):
    """Compile the source file filename and write the result to filename + "c".

    The source text is read, compiled through Module, and the code object
    is written (with a .pyc header) to a file whose name is the source name
    with "c" appended.  If display is true, Module.compile() also
    pretty-prints the parsed tree.

    Both files are closed even when reading, compiling or dumping raises.
    """
    f = open(filename)
    try:
        buf = f.read()
    finally:
        f.close()
    mod = Module(buf, filename)
    mod.compile(display)
    f = open(filename + "c", "wb")
    try:
        mod.dump(f)
    finally:
        f.close()
40
class Module:
    """One source module that can be compiled and dumped in .pyc form.

    Attributes:
      filename -- path of the source file; its mtime goes into the header
      source   -- the module's source text
      code     -- the compiled code object, or None until compile() runs
    """

    def __init__(self, source, filename):
        self.filename = filename
        self.source = source
        self.code = None

    def compile(self, display=0):
        """Parse and compile self.source, storing the code in self.code.

        A NestedScopeModuleCodeGenerator is used when
        future.find_futures() reports the "nested_scopes" feature for the
        tree; otherwise a plain ModuleCodeGenerator is used.  If display
        is true, the parsed tree is pretty-printed to stdout.
        """
        tree = parse(self.source)
        root, filename = os.path.split(self.filename)
        if "nested_scopes" in future.find_futures(tree):
            gen = NestedScopeModuleCodeGenerator(filename)
        else:
            gen = ModuleCodeGenerator(filename)
        walk(tree, gen, 1)
        if display:
            import pprint
            # pprint.pprint() writes the output itself and returns None;
            # printing its return value would add a spurious "None" line.
            pprint.pprint(tree)
        self.code = gen.getCode()

    def dump(self, f):
        """Write the .pyc header followed by the marshalled code to f."""
        f.write(self.getPycHeader())
        marshal.dump(self.code, f)

    MAGIC = imp.get_magic()

    def getPycHeader(self):
        """Return the .pyc header: the magic number, then the source mtime.

        compile.c writes the mtime as a raw marshal long, i.e. without the
        1-byte type code that marshal.dump() would prepend; that long is
        a 4-byte little-endian integer.  '<i' reproduces that layout on
        every platform, whereas plain 'i' uses native byte order (and
        size) and yields an unreadable header on big-endian hosts.
        """
        mtime = os.stat(self.filename)[stat.ST_MTIME]
        mtime = struct.pack('<i', mtime)
        return self.MAGIC + mtime
74
class LocalNameFinder:
    """Find local names in scope

    A visitor that records every name bound at the top level of one
    scope (function/class definitions, imports, assignment targets) and
    every name declared global.  Nested dicts and lambdas are not
    descended into.  getLocals() returns the bound names minus the
    declared globals.
    """
    def __init__(self, names=()):
        self.names = misc.Set()
        self.globals = misc.Set()
        add = self.names.add
        for initial in names:
            add(initial)

    # XXX list comprehensions and for loops

    def getLocals(self):
        """Remove declared globals from the bound names and return them."""
        bound = self.names
        for declared in self.globals.elements():
            if bound.has_elt(declared):
                bound.remove(declared)
        return bound

    def _bindAliases(self, pairs):
        # (name, alias) pairs: the alias, when present, is what gets bound.
        for original, alias in pairs:
            self.names.add(alias or original)

    def visitDict(self, node):
        pass

    def visitGlobal(self, node):
        for declared in node.names:
            self.globals.add(declared)

    def visitFunction(self, node):
        self.names.add(node.name)

    def visitLambda(self, node):
        pass

    def visitImport(self, node):
        self._bindAliases(node.names)

    def visitFrom(self, node):
        self._bindAliases(node.names)

    def visitClass(self, node):
        self.names.add(node.name)

    def visitAssName(self, node):
        self.names.add(node.name)
117
118class CodeGenerator:
119    """Defines basic code generator for Python bytecode
120
121    This class is an abstract base class.  Concrete subclasses must
122    define an __init__() that defines self.graph and then calls the
123    __init__() defined in this class.
124
125    The concrete class must also define the class attributes
126    NameFinder, FunctionGen, and ClassGen.  These attributes can be
127    defined in the initClass() method, which is a hook for
128    initializing these methods after all the classes have been
129    defined.
130    """
131
132    optimized = 0 # is namespace access optimized?
133    __initialized = None
134
135    def __init__(self, filename):
136        if self.__initialized is None:
137            self.initClass()
138            self.__class__.__initialized = 1
139        self.checkClass()
140        self.filename = filename
141        self.locals = misc.Stack()
142        self.loops = misc.Stack()
143        self.curStack = 0
144        self.maxStack = 0
145        self.last_lineno = None
146        self._setupGraphDelegation()
147
148    def initClass(self):
149        """This method is called once for each class"""
150
151    def checkClass(self):
152        """Verify that class is constructed correctly"""
153        try:
154            assert hasattr(self, 'graph')
155            assert getattr(self, 'NameFinder')
156            assert getattr(self, 'FunctionGen')
157            assert getattr(self, 'ClassGen')
158        except AssertionError, msg:
159            intro = "Bad class construction for %s" % self.__class__.__name__
160            raise AssertionError, intro
161
162    def _setupGraphDelegation(self):
163        self.emit = self.graph.emit
164        self.newBlock = self.graph.newBlock
165        self.startBlock = self.graph.startBlock
166        self.nextBlock = self.graph.nextBlock
167        self.setDocstring = self.graph.setDocstring
168
169    def getCode(self):
170        """Return a code object"""
171        return self.graph.getCode()
172
173    # Next five methods handle name access
174
175    def isLocalName(self, name):
176        return self.locals.top().has_elt(name)
177
178    def storeName(self, name):
179        self._nameOp('STORE', name)
180
181    def loadName(self, name):
182        self._nameOp('LOAD', name)
183
184    def delName(self, name):
185        self._nameOp('DELETE', name)
186
187    def _nameOp(self, prefix, name):
188        if not self.optimized:
189            self.emit(prefix + '_NAME', name)
190            return
191        if self.isLocalName(name):
192            self.emit(prefix + '_FAST', name)
193        else:
194            self.emit(prefix + '_GLOBAL', name)
195
196    def set_lineno(self, node):
197        """Emit SET_LINENO if node has lineno attribute and it is
198        different than the last lineno emitted.
199
200        Returns true if SET_LINENO was emitted.
201
202        There are no rules for when an AST node should have a lineno
203        attribute.  The transformer and AST code need to be reviewed
204        and a consistent policy implemented and documented.  Until
205        then, this method works around missing line numbers.
206        """
207        lineno = getattr(node, 'lineno', None)
208        if lineno is not None and lineno != self.last_lineno:
209            self.emit('SET_LINENO', lineno)
210            self.last_lineno = lineno
211            return 1
212        return 0
213
214    # The first few visitor methods handle nodes that generator new
215    # code objects.  They use class attributes to determine what
216    # specialized code generators to use.
217
218    NameFinder = LocalNameFinder
219    FunctionGen = None
220    ClassGen = None
221
222    def visitModule(self, node):
223        lnf = walk(node.node, self.NameFinder(), 0)
224        self.locals.push(lnf.getLocals())
225        if node.doc:
226            self.fixDocstring(node.node)
227        self.visit(node.node)
228        self.emit('LOAD_CONST', None)
229        self.emit('RETURN_VALUE')
230
231    def visitFunction(self, node):
232        self._visitFuncOrLambda(node, isLambda=0)
233        if node.doc:
234            self.setDocstring(node.doc)
235        self.storeName(node.name)
236
237    def visitLambda(self, node):
238        self._visitFuncOrLambda(node, isLambda=1)
239
240    def _visitFuncOrLambda(self, node, isLambda=0):
241        gen = self.FunctionGen(node, self.filename, self.scopes, isLambda)
242        walk(node.code, gen)
243        gen.finish()
244        self.set_lineno(node)
245        for default in node.defaults:
246            self.visit(default)
247        self.emit('LOAD_CONST', gen)
248        self.emit('MAKE_FUNCTION', len(node.defaults))
249
250    def visitClass(self, node):
251        gen = self.ClassGen(node, self.filename, self.scopes)
252        if node.doc:
253            self.fixDocstring(node.code)
254        walk(node.code, gen)
255        gen.finish()
256        self.set_lineno(node)
257        self.emit('LOAD_CONST', node.name)
258        for base in node.bases:
259            self.visit(base)
260        self.emit('BUILD_TUPLE', len(node.bases))
261        self.emit('LOAD_CONST', gen)
262        self.emit('MAKE_FUNCTION', 0)
263        self.emit('CALL_FUNCTION', 0)
264        self.emit('BUILD_CLASS')
265        self.storeName(node.name)
266
267    def fixDocstring(self, node):
268        """Rewrite the ast for a class with a docstring.
269
270        The AST includes a Discard(Const(docstring)) node.  Replace
271        this with an Assign([AssName('__doc__', ...])
272        """
273        assert isinstance(node, ast.Stmt)
274        stmts = node.nodes
275        discard = stmts[0]
276        assert isinstance(discard, ast.Discard)
277        stmts[0] = ast.Assign([ast.AssName('__doc__', 'OP_ASSIGN')],
278                              discard.expr)
279        stmts[0].lineno = discard.lineno
280    # The rest are standard visitor methods
281
282    # The next few implement control-flow statements
283
284    def visitIf(self, node):
285        end = self.newBlock()
286        numtests = len(node.tests)
287        for i in range(numtests):
288            test, suite = node.tests[i]
289            self.set_lineno(test)
290            self.visit(test)
291            nextTest = self.newBlock()
292            self.emit('JUMP_IF_FALSE', nextTest)
293            self.nextBlock()
294            self.emit('POP_TOP')
295            self.visit(suite)
296            self.emit('JUMP_FORWARD', end)
297            self.startBlock(nextTest)
298            self.emit('POP_TOP')
299        if node.else_:
300            self.visit(node.else_)
301        self.nextBlock(end)
302
303    def visitWhile(self, node):
304        self.set_lineno(node)
305
306        loop = self.newBlock()
307        else_ = self.newBlock()
308
309        after = self.newBlock()
310        self.emit('SETUP_LOOP', after)
311
312        self.nextBlock(loop)
313        self.loops.push(loop)
314
315        self.set_lineno(node)
316        self.visit(node.test)
317        self.emit('JUMP_IF_FALSE', else_ or after)
318
319        self.nextBlock()
320        self.emit('POP_TOP')
321        self.visit(node.body)
322        self.emit('JUMP_ABSOLUTE', loop)
323
324        self.startBlock(else_) # or just the POPs if not else clause
325        self.emit('POP_TOP')
326        self.emit('POP_BLOCK')
327        if node.else_:
328            self.visit(node.else_)
329        self.loops.pop()
330        self.nextBlock(after)
331
332    def visitFor(self, node):
333        start = self.newBlock()
334        anchor = self.newBlock()
335        after = self.newBlock()
336        self.loops.push(start)
337
338        self.set_lineno(node)
339        self.emit('SETUP_LOOP', after)
340        self.visit(node.list)
341        self.visit(ast.Const(0))
342        self.nextBlock(start)
343        self.set_lineno(node)
344        self.emit('FOR_LOOP', anchor)
345        self.nextBlock()
346        self.visit(node.assign)
347        self.visit(node.body)
348        self.emit('JUMP_ABSOLUTE', start)
349        self.startBlock(anchor)
350        self.emit('POP_BLOCK')
351        if node.else_:
352            self.visit(node.else_)
353        self.loops.pop()
354        self.nextBlock(after)
355
356    def visitBreak(self, node):
357        if not self.loops:
358            raise SyntaxError, "'break' outside loop (%s, %d)" % \
359                  (self.filename, node.lineno)
360        self.set_lineno(node)
361        self.emit('BREAK_LOOP')
362
363    def visitContinue(self, node):
364        if not self.loops:
365            raise SyntaxError, "'continue' outside loop (%s, %d)" % \
366                  (self.filename, node.lineno)
367        l = self.loops.top()
368        self.set_lineno(node)
369        self.emit('JUMP_ABSOLUTE', l)
370        self.nextBlock()
371
372    def visitTest(self, node, jump):
373        end = self.newBlock()
374        for child in node.nodes[:-1]:
375            self.visit(child)
376            self.emit(jump, end)
377            self.nextBlock()
378            self.emit('POP_TOP')
379        self.visit(node.nodes[-1])
380        self.nextBlock(end)
381
382    def visitAnd(self, node):
383        self.visitTest(node, 'JUMP_IF_FALSE')
384
385    def visitOr(self, node):
386        self.visitTest(node, 'JUMP_IF_TRUE')
387
388    def visitCompare(self, node):
389        self.visit(node.expr)
390        cleanup = self.newBlock()
391        for op, code in node.ops[:-1]:
392            self.visit(code)
393            self.emit('DUP_TOP')
394            self.emit('ROT_THREE')
395            self.emit('COMPARE_OP', op)
396            self.emit('JUMP_IF_FALSE', cleanup)
397            self.nextBlock()
398            self.emit('POP_TOP')
399        # now do the last comparison
400        if node.ops:
401            op, code = node.ops[-1]
402            self.visit(code)
403            self.emit('COMPARE_OP', op)
404        if len(node.ops) > 1:
405            end = self.newBlock()
406            self.emit('JUMP_FORWARD', end)
407            self.startBlock(cleanup)
408            self.emit('ROT_TWO')
409            self.emit('POP_TOP')
410            self.nextBlock(end)
411
412    # list comprehensions
413    __list_count = 0
414
415    def visitListComp(self, node):
416        # XXX would it be easier to transform the AST into the form it
417        # would have if the list comp were expressed as a series of
418        # for and if stmts and an explicit append?
419        self.set_lineno(node)
420        # setup list
421        append = "$append%d" % self.__list_count
422        self.__list_count = self.__list_count + 1
423        self.emit('BUILD_LIST', 0)
424        self.emit('DUP_TOP')
425        self.emit('LOAD_ATTR', 'append')
426        self.storeName(append)
427        l = len(node.quals)
428        stack = []
429        for i, for_ in zip(range(l), node.quals):
430            start, anchor = self.visit(for_)
431            cont = None
432            for if_ in for_.ifs:
433                if cont is None:
434                    cont = self.newBlock()
435                self.visit(if_, cont)
436            stack.insert(0, (start, cont, anchor))
437
438        self.loadName(append)
439        self.visit(node.expr)
440        self.emit('CALL_FUNCTION', 1)
441        self.emit('POP_TOP')
442
443        for start, cont, anchor in stack:
444            if cont:
445                skip_one = self.newBlock()
446                self.emit('JUMP_FORWARD', skip_one)
447                self.startBlock(cont)
448                self.emit('POP_TOP')
449                self.nextBlock(skip_one)
450            self.emit('JUMP_ABSOLUTE', start)
451            self.startBlock(anchor)
452        self.delName(append)
453
454        self.__list_count = self.__list_count - 1
455
456    def visitListCompFor(self, node):
457        self.set_lineno(node)
458        start = self.newBlock()
459        anchor = self.newBlock()
460
461        self.visit(node.list)
462        self.visit(ast.Const(0))
463        self.emit('SET_LINENO', node.lineno)
464        self.nextBlock(start)
465        self.emit('FOR_LOOP', anchor)
466        self.nextBlock()
467        self.visit(node.assign)
468        return start, anchor
469
470    def visitListCompIf(self, node, branch):
471        self.set_lineno(node)
472        self.visit(node.test)
473        self.emit('JUMP_IF_FALSE', branch)
474        self.newBlock()
475        self.emit('POP_TOP')
476
477    # exception related
478
479    def visitAssert(self, node):
480        # XXX would be interesting to implement this via a
481        # transformation of the AST before this stage
482        end = self.newBlock()
483        self.set_lineno(node)
484        # XXX __debug__ and AssertionError appear to be special cases
485        # -- they are always loaded as globals even if there are local
486        # names.  I guess this is a sort of renaming op.
487        self.emit('LOAD_GLOBAL', '__debug__')
488        self.emit('JUMP_IF_FALSE', end)
489        self.nextBlock()
490        self.emit('POP_TOP')
491        self.visit(node.test)
492        self.emit('JUMP_IF_TRUE', end)
493        self.nextBlock()
494        self.emit('POP_TOP')
495        self.emit('LOAD_GLOBAL', 'AssertionError')
496        if node.fail:
497            self.visit(node.fail)
498            self.emit('RAISE_VARARGS', 2)
499        else:
500            self.emit('RAISE_VARARGS', 1)
501        self.nextBlock(end)
502        self.emit('POP_TOP')
503
504    def visitRaise(self, node):
505        self.set_lineno(node)
506        n = 0
507        if node.expr1:
508            self.visit(node.expr1)
509            n = n + 1
510        if node.expr2:
511            self.visit(node.expr2)
512            n = n + 1
513        if node.expr3:
514            self.visit(node.expr3)
515            n = n + 1
516        self.emit('RAISE_VARARGS', n)
517
518    def visitTryExcept(self, node):
519        handlers = self.newBlock()
520        end = self.newBlock()
521        if node.else_:
522            lElse = self.newBlock()
523        else:
524            lElse = end
525        self.set_lineno(node)
526        self.emit('SETUP_EXCEPT', handlers)
527        self.nextBlock()
528        self.visit(node.body)
529        self.emit('POP_BLOCK')
530        self.emit('JUMP_FORWARD', lElse)
531        self.startBlock(handlers)
532
533        last = len(node.handlers) - 1
534        for i in range(len(node.handlers)):
535            expr, target, body = node.handlers[i]
536            self.set_lineno(expr)
537            if expr:
538                self.emit('DUP_TOP')
539                self.visit(expr)
540                self.emit('COMPARE_OP', 'exception match')
541                next = self.newBlock()
542                self.emit('JUMP_IF_FALSE', next)
543                self.nextBlock()
544                self.emit('POP_TOP')
545            self.emit('POP_TOP')
546            if target:
547                self.visit(target)
548            else:
549                self.emit('POP_TOP')
550            self.emit('POP_TOP')
551            self.visit(body)
552            self.emit('JUMP_FORWARD', end)
553            if expr:
554                self.nextBlock(next)
555            else:
556                self.nextBlock()
557            self.emit('POP_TOP')
558        self.emit('END_FINALLY')
559        if node.else_:
560            self.nextBlock(lElse)
561            self.visit(node.else_)
562        self.nextBlock(end)
563
564    def visitTryFinally(self, node):
565        final = self.newBlock()
566        self.set_lineno(node)
567        self.emit('SETUP_FINALLY', final)
568        self.nextBlock()
569        self.visit(node.body)
570        self.emit('POP_BLOCK')
571        self.emit('LOAD_CONST', None)
572        self.nextBlock(final)
573        self.visit(node.final)
574        self.emit('END_FINALLY')
575
576    # misc
577
578    def visitDiscard(self, node):
579        self.set_lineno(node)
580        self.visit(node.expr)
581        self.emit('POP_TOP')
582
583    def visitConst(self, node):
584        self.emit('LOAD_CONST', node.value)
585
586    def visitKeyword(self, node):
587        self.emit('LOAD_CONST', node.name)
588        self.visit(node.expr)
589
590    def visitGlobal(self, node):
591        # no code to generate
592        pass
593
594    def visitName(self, node):
595        self.set_lineno(node)
596        self.loadName(node.name)
597
598    def visitPass(self, node):
599        self.set_lineno(node)
600
601    def visitImport(self, node):
602        self.set_lineno(node)
603        for name, alias in node.names:
604            if VERSION > 1:
605                self.emit('LOAD_CONST', None)
606            self.emit('IMPORT_NAME', name)
607            mod = string.split(name, ".")[0]
608            self.storeName(alias or mod)
609
610    def visitFrom(self, node):
611        self.set_lineno(node)
612        fromlist = map(lambda (name, alias): name, node.names)
613        if VERSION > 1:
614            self.emit('LOAD_CONST', tuple(fromlist))
615        self.emit('IMPORT_NAME', node.modname)
616        for name, alias in node.names:
617            if VERSION > 1:
618                if name == '*':
619                    self.namespace = 0
620                    self.emit('IMPORT_STAR')
621                    # There can only be one name w/ from ... import *
622                    assert len(node.names) == 1
623                    return
624                else:
625                    self.emit('IMPORT_FROM', name)
626                    self._resolveDots(name)
627                    self.storeName(alias or name)
628            else:
629                self.emit('IMPORT_FROM', name)
630        self.emit('POP_TOP')
631
632    def _resolveDots(self, name):
633        elts = string.split(name, ".")
634        if len(elts) == 1:
635            return
636        for elt in elts[1:]:
637            self.emit('LOAD_ATTR', elt)
638
639    def visitGetattr(self, node):
640        self.visit(node.expr)
641        self.emit('LOAD_ATTR', node.attrname)
642
643    # next five implement assignments
644
645    def visitAssign(self, node):
646        self.set_lineno(node)
647        self.visit(node.expr)
648        dups = len(node.nodes) - 1
649        for i in range(len(node.nodes)):
650            elt = node.nodes[i]
651            if i < dups:
652                self.emit('DUP_TOP')
653            if isinstance(elt, ast.Node):
654                self.visit(elt)
655
656    def visitAssName(self, node):
657        if node.flags == 'OP_ASSIGN':
658            self.storeName(node.name)
659        elif node.flags == 'OP_DELETE':
660            self.delName(node.name)
661        else:
662            print "oops", node.flags
663
664    def visitAssAttr(self, node):
665        self.visit(node.expr)
666        if node.flags == 'OP_ASSIGN':
667            self.emit('STORE_ATTR', node.attrname)
668        elif node.flags == 'OP_DELETE':
669            self.emit('DELETE_ATTR', node.attrname)
670        else:
671            print "warning: unexpected flags:", node.flags
672            print node
673
674    def _visitAssSequence(self, node, op='UNPACK_SEQUENCE'):
675        if findOp(node) != 'OP_DELETE':
676            self.emit(op, len(node.nodes))
677        for child in node.nodes:
678            self.visit(child)
679
680    if VERSION > 1:
681        visitAssTuple = _visitAssSequence
682        visitAssList = _visitAssSequence
683    else:
684        def visitAssTuple(self, node):
685            self._visitAssSequence(node, 'UNPACK_TUPLE')
686
687        def visitAssList(self, node):
688            self._visitAssSequence(node, 'UNPACK_LIST')
689
690    # augmented assignment
691
692    def visitAugAssign(self, node):
693        aug_node = wrap_aug(node.node)
694        self.visit(aug_node, "load")
695        self.visit(node.expr)
696        self.emit(self._augmented_opcode[node.op])
697        self.visit(aug_node, "store")
698
699    _augmented_opcode = {
700        '+=' : 'INPLACE_ADD',
701        '-=' : 'INPLACE_SUBTRACT',
702        '*=' : 'INPLACE_MULTIPLY',
703        '/=' : 'INPLACE_DIVIDE',
704        '%=' : 'INPLACE_MODULO',
705        '**=': 'INPLACE_POWER',
706        '>>=': 'INPLACE_RSHIFT',
707        '<<=': 'INPLACE_LSHIFT',
708        '&=' : 'INPLACE_AND',
709        '^=' : 'INPLACE_XOR',
710        '|=' : 'INPLACE_OR',
711        }
712
713    def visitAugName(self, node, mode):
714        if mode == "load":
715            self.loadName(node.name)
716        elif mode == "store":
717            self.storeName(node.name)
718
719    def visitAugGetattr(self, node, mode):
720        if mode == "load":
721            self.visit(node.expr)
722            self.emit('DUP_TOP')
723            self.emit('LOAD_ATTR', node.attrname)
724        elif mode == "store":
725            self.emit('ROT_TWO')
726            self.emit('STORE_ATTR', node.attrname)
727
728    def visitAugSlice(self, node, mode):
729        if mode == "load":
730            self.visitSlice(node, 1)
731        elif mode == "store":
732            slice = 0
733            if node.lower:
734                slice = slice | 1
735            if node.upper:
736                slice = slice | 2
737            if slice == 0:
738                self.emit('ROT_TWO')
739            elif slice == 3:
740                self.emit('ROT_FOUR')
741            else:
742                self.emit('ROT_THREE')
743            self.emit('STORE_SLICE+%d' % slice)
744
745    def visitAugSubscript(self, node, mode):
746        if len(node.subs) > 1:
747            raise SyntaxError, "augmented assignment to tuple is not possible"
748        if mode == "load":
749            self.visitSubscript(node, 1)
750        elif mode == "store":
751            self.emit('ROT_THREE')
752            self.emit('STORE_SUBSCR')
753
754    def visitExec(self, node):
755        self.visit(node.expr)
756        if node.locals is None:
757            self.emit('LOAD_CONST', None)
758        else:
759            self.visit(node.locals)
760        if node.globals is None:
761            self.emit('DUP_TOP')
762        else:
763            self.visit(node.globals)
764        self.emit('EXEC_STMT')
765
766    def visitCallFunc(self, node):
767        pos = 0
768        kw = 0
769        self.set_lineno(node)
770        self.visit(node.node)
771        for arg in node.args:
772            self.visit(arg)
773            if isinstance(arg, ast.Keyword):
774                kw = kw + 1
775            else:
776                pos = pos + 1
777        if node.star_args is not None:
778            self.visit(node.star_args)
779        if node.dstar_args is not None:
780            self.visit(node.dstar_args)
781        have_star = node.star_args is not None
782        have_dstar = node.dstar_args is not None
783        opcode = callfunc_opcode_info[have_star, have_dstar]
784        self.emit(opcode, kw << 8 | pos)
785
786    def visitPrint(self, node):
787        self.set_lineno(node)
788        if node.dest:
789            self.visit(node.dest)
790        for child in node.nodes:
791            if node.dest:
792                self.emit('DUP_TOP')
793            self.visit(child)
794            if node.dest:
795                self.emit('ROT_TWO')
796                self.emit('PRINT_ITEM_TO')
797            else:
798                self.emit('PRINT_ITEM')
799
800    def visitPrintnl(self, node):
801        self.visitPrint(node)
802        if node.dest:
803            self.emit('PRINT_NEWLINE_TO')
804        else:
805            self.emit('PRINT_NEWLINE')
806
807    def visitReturn(self, node):
808        self.set_lineno(node)
809        self.visit(node.value)
810        self.emit('RETURN_VALUE')
811
812    # slice and subscript stuff
813
814    def visitSlice(self, node, aug_flag=None):
815        # aug_flag is used by visitAugSlice
816        self.visit(node.expr)
817        slice = 0
818        if node.lower:
819            self.visit(node.lower)
820            slice = slice | 1
821        if node.upper:
822            self.visit(node.upper)
823            slice = slice | 2
824        if aug_flag:
825            if slice == 0:
826                self.emit('DUP_TOP')
827            elif slice == 3:
828                self.emit('DUP_TOPX', 3)
829            else:
830                self.emit('DUP_TOPX', 2)
831        if node.flags == 'OP_APPLY':
832            self.emit('SLICE+%d' % slice)
833        elif node.flags == 'OP_ASSIGN':
834            self.emit('STORE_SLICE+%d' % slice)
835        elif node.flags == 'OP_DELETE':
836            self.emit('DELETE_SLICE+%d' % slice)
837        else:
838            print "weird slice", node.flags
839            raise
840
841    def visitSubscript(self, node, aug_flag=None):
842        self.visit(node.expr)
843        for sub in node.subs:
844            self.visit(sub)
845        if aug_flag:
846            self.emit('DUP_TOPX', 2)
847        if len(node.subs) > 1:
848            self.emit('BUILD_TUPLE', len(node.subs))
849        if node.flags == 'OP_APPLY':
850            self.emit('BINARY_SUBSCR')
851        elif node.flags == 'OP_ASSIGN':
852            self.emit('STORE_SUBSCR')
853        elif node.flags == 'OP_DELETE':
854            self.emit('DELETE_SUBSCR')
855
856    # binary ops
857
858    def binaryOp(self, node, op):
859        self.visit(node.left)
860        self.visit(node.right)
861        self.emit(op)
862
863    def visitAdd(self, node):
864        return self.binaryOp(node, 'BINARY_ADD')
865
866    def visitSub(self, node):
867        return self.binaryOp(node, 'BINARY_SUBTRACT')
868
869    def visitMul(self, node):
870        return self.binaryOp(node, 'BINARY_MULTIPLY')
871
872    def visitDiv(self, node):
873        return self.binaryOp(node, 'BINARY_DIVIDE')
874
875    def visitMod(self, node):
876        return self.binaryOp(node, 'BINARY_MODULO')
877
878    def visitPower(self, node):
879        return self.binaryOp(node, 'BINARY_POWER')
880
881    def visitLeftShift(self, node):
882        return self.binaryOp(node, 'BINARY_LSHIFT')
883
884    def visitRightShift(self, node):
885        return self.binaryOp(node, 'BINARY_RSHIFT')
886
887    # unary ops
888
889    def unaryOp(self, node, op):
890        self.visit(node.expr)
891        self.emit(op)
892
893    def visitInvert(self, node):
894        return self.unaryOp(node, 'UNARY_INVERT')
895
896    def visitUnarySub(self, node):
897        return self.unaryOp(node, 'UNARY_NEGATIVE')
898
899    def visitUnaryAdd(self, node):
900        return self.unaryOp(node, 'UNARY_POSITIVE')
901
902    def visitUnaryInvert(self, node):
903        return self.unaryOp(node, 'UNARY_INVERT')
904
905    def visitNot(self, node):
906        return self.unaryOp(node, 'UNARY_NOT')
907
908    def visitBackquote(self, node):
909        return self.unaryOp(node, 'UNARY_CONVERT')
910
911    # bit ops
912
913    def bitOp(self, nodes, op):
914        self.visit(nodes[0])
915        for node in nodes[1:]:
916            self.visit(node)
917            self.emit(op)
918
919    def visitBitand(self, node):
920        return self.bitOp(node.nodes, 'BINARY_AND')
921
922    def visitBitor(self, node):
923        return self.bitOp(node.nodes, 'BINARY_OR')
924
925    def visitBitxor(self, node):
926        return self.bitOp(node.nodes, 'BINARY_XOR')
927
928    # object constructors
929
930    def visitEllipsis(self, node):
931        self.emit('LOAD_CONST', Ellipsis)
932
933    def visitTuple(self, node):
934        for elt in node.nodes:
935            self.visit(elt)
936        self.emit('BUILD_TUPLE', len(node.nodes))
937
938    def visitList(self, node):
939        for elt in node.nodes:
940            self.visit(elt)
941        self.emit('BUILD_LIST', len(node.nodes))
942
943    def visitSliceobj(self, node):
944        for child in node.nodes:
945            self.visit(child)
946        self.emit('BUILD_SLICE', len(node.nodes))
947
948    def visitDict(self, node):
949        lineno = getattr(node, 'lineno', None)
950        if lineno:
951            set.emit('SET_LINENO', lineno)
952        self.emit('BUILD_MAP', 0)
953        for k, v in node.items:
954            lineno2 = getattr(node, 'lineno', None)
955            if lineno2 is not None and lineno != lineno2:
956                self.emit('SET_LINENO', lineno2)
957                lineno = lineno2
958            self.emit('DUP_TOP')
959            self.visit(v)
960            self.emit('ROT_TWO')
961            self.visit(k)
962            self.emit('STORE_SUBSCR')
963
class NestedScopeCodeGenerator(CodeGenerator):
    """CodeGenerator variant that resolves names through symbol-table scopes.

    Names are classified by the scope objects that symbols.SymbolVisitor
    produces. Free and cell variables use the *_DEREF opcodes, and
    functions/classes with free variables are built with MAKE_CLOSURE.
    """
    __super_visitModule = CodeGenerator.visitModule
    __super_visitClass = CodeGenerator.visitClass
    __super__visitFuncOrLambda = CodeGenerator._visitFuncOrLambda

    def parseSymbols(self, tree):
        """Walk *tree* with a SymbolVisitor and return its scopes mapping."""
        visitor = symbols.SymbolVisitor()
        walk(tree, visitor)
        return visitor.scopes

    def visitModule(self, node):
        """Compute scopes for the whole module, then generate module code."""
        scopes = self.parseSymbols(node)
        self.scopes = scopes
        self.scope = scopes[node]
        self.__super_visitModule(node)

    def _nameOp(self, prefix, name):
        """Emit ``prefix`` + a suffix chosen from the scope of *name*.

        Locals use _FAST in optimized code and _NAME otherwise; globals use
        _GLOBAL; free and cell variables use _DEREF. Any other scope raises
        RuntimeError. (prefix is presumably an opcode stem such as 'LOAD';
        the callers are outside this view.)
        """
        kind = self.scope.check_name(name)
        if kind == SC_LOCAL:
            if self.optimized:
                suffix = '_FAST'
            else:
                suffix = '_NAME'
        elif kind == SC_GLOBAL:
            suffix = '_GLOBAL'
        elif kind in (SC_FREE, SC_CELL):
            suffix = '_DEREF'
        else:
            raise RuntimeError("unsupported scope for var %s: %d" %
                               (name, kind))
        self.emit(prefix + suffix, name)

    def _visitFuncOrLambda(self, node, isLambda=0):
        """Compile a function/lambda body, then emit code building the object.

        Defaults are pushed first. Each free variable is loaded with
        LOAD_CLOSURE, and the object is built with MAKE_CLOSURE when there
        are free variables, otherwise with MAKE_FUNCTION.
        """
        gen = self.FunctionGen(node, self.filename, self.scopes, isLambda)
        walk(node.code, gen)
        gen.finish()
        self.set_lineno(node)
        defaults = node.defaults
        for default in defaults:
            self.visit(default)
        frees = gen.scope.get_free_vars()
        for name in frees:
            self.emit('LOAD_CLOSURE', name)
        self.emit('LOAD_CONST', gen)
        if frees:
            self.emit('MAKE_CLOSURE', len(defaults))
        else:
            self.emit('MAKE_FUNCTION', len(defaults))

    def visitClass(self, node):
        """Compile a class body, then emit the BUILD_CLASS sequence.

        Pushes the class name, a tuple of the bases, and the body
        function/closure called with no arguments, then BUILD_CLASS and a
        store of the result under the class name.
        """
        # The body generator is created before fixDocstring runs, as before.
        gen = self.ClassGen(node, self.filename, self.scopes)
        if node.doc:
            self.fixDocstring(node.code)
        walk(node.code, gen)
        gen.finish()
        self.set_lineno(node)
        self.emit('LOAD_CONST', node.name)
        bases = node.bases
        for base in bases:
            self.visit(base)
        self.emit('BUILD_TUPLE', len(bases))
        frees = gen.scope.get_free_vars()
        for name in frees:
            self.emit('LOAD_CLOSURE', name)
        self.emit('LOAD_CONST', gen)
        if frees:
            maker = 'MAKE_CLOSURE'
        else:
            maker = 'MAKE_FUNCTION'
        self.emit(maker, 0)
        self.emit('CALL_FUNCTION', 0)
        self.emit('BUILD_CLASS')
        self.storeName(node.name)
1033
1034
class LGBScopeMixin:
    """Defines initClass() for Python 2.1-compatible (LGB) scoping."""
    def initClass(self):
        """Install the LGB-scoping helper classes on this object's class."""
        cls = self.__class__
        cls.NameFinder = LocalNameFinder
        cls.FunctionGen = FunctionCodeGenerator
        cls.ClassGen = ClassCodeGenerator
1041
class NestedScopeMixin:
    """Defines initClass() for nested scoping (Python 2.2-compatible)."""
    def initClass(self):
        """Install the nested-scope helper classes on this object's class."""
        cls = self.__class__
        cls.NameFinder = LocalNameFinder
        cls.FunctionGen = NestedFunctionCodeGenerator
        cls.ClassGen = NestedClassCodeGenerator
1048
class ModuleCodeGenerator(LGBScopeMixin, CodeGenerator):
    """Top-level code generator for a module compiled with LGB scoping."""
    __super_init = CodeGenerator.__init__

    scopes = None

    def __init__(self, filename):
        # self.graph is assigned before the base initializer runs, as in the
        # other generators here (presumably the base __init__ relies on it).
        graph = pyassem.PyFlowGraph("<module>", filename)
        self.graph = graph
        self.__super_init(filename)
1057
class NestedScopeModuleCodeGenerator(NestedScopeMixin,
                                     NestedScopeCodeGenerator):
    """Top-level code generator for a module that uses nested scopes."""
    __super_init = CodeGenerator.__init__

    def __init__(self, filename):
        module_name = "<module>"
        self.graph = pyassem.PyFlowGraph(module_name, filename)
        self.__super_init(filename)
        # Flag the resulting module code as compiled with nested scopes.
        self.graph.setFlag(CO_NESTED)
1066
class AbstractFunctionCode:
    """Shared setup/teardown for generators that compile a function or lambda.

    Subclasses are expected to provide ``super_init`` (run once self.graph
    exists) and ``NameFinder`` as class attributes.
    """
    optimized = 1
    lambdaCount = 0

    def __init__(self, func, filename, scopes, isLambda):
        if isLambda:
            # The counter is kept on FunctionCodeGenerator, so lambdas compiled
            # by either generator flavour share one numbering sequence.
            owner = FunctionCodeGenerator
            serial = owner.lambdaCount
            owner.lambdaCount = serial + 1
            name = "<lambda.%d>" % serial
        else:
            name = func.name
        args, hasTupleArg = generateArgList(func.argnames)
        self.graph = pyassem.PyFlowGraph(name, filename, args,
                                         optimized=1)
        self.isLambda = isLambda
        self.super_init(filename)

        # isLambda is tested first, so func.doc is only read for non-lambdas.
        if not isLambda:
            if func.doc:
                self.setDocstring(func.doc)

        finder = self.NameFinder(args)
        lnf = walk(func.code, finder, 0)
        self.locals.push(lnf.getLocals())
        graph = self.graph
        if func.varargs:
            graph.setFlag(CO_VARARGS)
        if func.kwargs:
            graph.setFlag(CO_VARKEYWORDS)
        self.set_lineno(func)
        if hasTupleArg:
            self.generateArgUnpack(func.argnames)

    def finish(self):
        """Start the exit block and emit RETURN_VALUE.

        Non-lambdas push None first. For a lambda no constant is pushed, so
        the value already on the stack (presumably the lambda's expression
        result) is returned.
        """
        self.graph.startExitBlock()
        if not self.isLambda:
            self.emit('LOAD_CONST', None)
        self.emit('RETURN_VALUE')

    def generateArgUnpack(self, args):
        """For each tuple parameter, load its '.nestedN' slot and unpack it."""
        nested = 0
        for arg in args:
            if type(arg) is not types.TupleType:
                continue
            self.emit('LOAD_FAST', '.nested%d' % nested)
            nested = nested + 1
            self.unpackSequence(arg)

    def unpackSequence(self, tup):
        """Unpack a (possibly nested) tuple of names into fast locals.

        UNPACK_SEQUENCE is used on Python 2+, UNPACK_TUPLE on Python 1.x.
        """
        if VERSION > 1:
            opname = 'UNPACK_SEQUENCE'
        else:
            opname = 'UNPACK_TUPLE'
        self.emit(opname, len(tup))
        for elt in tup:
            if type(elt) is types.TupleType:
                self.unpackSequence(elt)
            else:
                self.emit('STORE_FAST', elt)

    unpackTuple = unpackSequence
1123
class FunctionCodeGenerator(LGBScopeMixin, AbstractFunctionCode,
                            CodeGenerator):
    """Function/lambda code generator using Python 2.1 (LGB) scoping."""
    scopes = None
    # Called by AbstractFunctionCode.__init__ once self.graph exists.
    super_init = CodeGenerator.__init__
1128
class NestedFunctionCodeGenerator(AbstractFunctionCode,
                                  NestedScopeMixin,
                                  NestedScopeCodeGenerator):
    """Function/lambda code generator using nested (2.2-style) scoping."""
    # Called by AbstractFunctionCode.__init__ once self.graph exists.
    super_init = NestedScopeCodeGenerator.__init__
    __super_init = AbstractFunctionCode.__init__

    def __init__(self, func, filename, scopes, isLambda):
        # Scope info must be in place before the shared initializer runs.
        self.scopes = scopes
        self.scope = scopes[func]
        self.__super_init(func, filename, scopes, isLambda)
        scope = self.scope
        graph = self.graph
        graph.setFreeVars(scope.get_free_vars())
        graph.setCellVars(scope.get_cell_vars())
        graph.setFlag(CO_NESTED)
1142
class AbstractClassCode:
    """Shared setup/teardown for generators that compile a class body.

    Subclasses are expected to provide ``super_init`` and ``NameFinder``.
    """

    def __init__(self, klass, filename, scopes):
        self.graph = pyassem.PyFlowGraph(klass.name, filename,
                                           optimized=0)
        self.super_init(filename)
        finder = self.NameFinder()
        lnf = walk(klass.code, finder, 0)
        self.locals.push(lnf.getLocals())
        self.graph.setFlag(CO_NEWLOCALS)
        doc = klass.doc
        if doc:
            self.setDocstring(doc)

    def finish(self):
        """Start the exit block and return the body's locals (LOAD_LOCALS)."""
        self.graph.startExitBlock()
        self.emit('LOAD_LOCALS')
        self.emit('RETURN_VALUE')
1159
class ClassCodeGenerator(LGBScopeMixin, AbstractClassCode, CodeGenerator):
    """Class-body code generator using Python 2.1 (LGB) scoping."""
    scopes = None
    # Called by AbstractClassCode.__init__ once self.graph exists.
    super_init = CodeGenerator.__init__
1163
class NestedClassCodeGenerator(AbstractClassCode,
                               NestedScopeMixin,
                               NestedScopeCodeGenerator):
    """Class-body code generator using nested (2.2-style) scoping."""
    # Called by AbstractClassCode.__init__ once self.graph exists.
    super_init = NestedScopeCodeGenerator.__init__
    __super_init = AbstractClassCode.__init__

    def __init__(self, klass, filename, scopes):
        # Scope info must be in place before the shared initializer runs.
        self.scopes = scopes
        self.scope = scopes[klass]
        self.__super_init(klass, filename, scopes)
        scope = self.scope
        graph = self.graph
        graph.setFreeVars(scope.get_free_vars())
        graph.setCellVars(scope.get_cell_vars())
        graph.setFlag(CO_NESTED)
1177
def generateArgList(arglist):
    """Generate an arg list marking TupleArgs.

    Returns ``(args, count)``. ``args`` contains, in order:
    - each plain (string) argument name;
    - a ``TupleArg(index, elt)`` placeholder for each tuple parameter;
    - after all of those, the names produced by ``misc.flatten`` for every
      tuple parameter.
    ``count`` is the number of tuple parameters.

    Raises ValueError if an element is neither a string nor a tuple.
    """
    string_type = type("")   # same object as types.StringType
    tuple_type = type(())    # same object as types.TupleType
    args = []
    extra = []
    count = 0
    for elt in arglist:
        if type(elt) == string_type:
            args.append(elt)
        elif type(elt) == tuple_type:
            args.append(TupleArg(count, elt))
            count = count + 1
            extra.extend(misc.flatten(elt))
        else:
            # The old three-argument `raise ValueError, msg, elt` treated elt
            # as a traceback, so Python raised TypeError instead.
            raise ValueError("unexpected argument type: %r" % (elt,))
    return args + extra, count
1193
def findOp(node):
    """Find the op (DELETE, LOAD, STORE) in an AssTuple tree.

    Walks *node* with an OpFinder and returns the flags it recorded, or
    None if no AssName node was visited.
    """
    finder = OpFinder()
    walk(node, finder, 0)
    return finder.op
1199
class OpFinder:
    """Visitor that records the flags shared by every AssName it visits.

    ``op`` stays None until the first AssName is seen. A later AssName with
    different flags raises ValueError.
    """
    def __init__(self):
        self.op = None

    def visitAssName(self, node):
        flags = node.flags
        if self.op is None:
            self.op = flags
            return
        if self.op != flags:
            raise ValueError("mixed ops in stmt")
1208
class Delegator:
    """Base class for the augmented-assignment wrapper nodes.

    visitAugAssign visits the left-hand expression twice. The first visit
    uses the node's normal visit method. The second must generate the code
    that stores the result, so the node is wrapped in a Delegator subclass
    whose class name selects a different visit method. Attributes the
    wrapper does not define itself are read from the wrapped node.
    """
    def __init__(self, obj):
        self.obj = obj

    def __getattr__(self, attr):
        # Only reached when normal lookup fails; forward to the wrapped node.
        target = self.obj
        return getattr(target, attr)
1225
class AugGetattr(Delegator):
    """Augmented-assignment wrapper for an attribute-reference target."""

class AugName(Delegator):
    """Augmented-assignment wrapper for a plain-name target."""

class AugSlice(Delegator):
    """Augmented-assignment wrapper for a slice target."""

class AugSubscript(Delegator):
    """Augmented-assignment wrapper for a subscript target."""
1237
# Maps each AST node class that may be the target of an augmented
# assignment to the Delegator subclass used to wrap it (see wrap_aug).
wrapper = {
    ast.Name: AugName,
    ast.Getattr: AugGetattr,
    ast.Subscript: AugSubscript,
    ast.Slice: AugSlice,
    }
1244
def wrap_aug(node):
    """Wrap *node* in the delegator class registered for its class in wrapper.

    Raises KeyError if node's class has no entry in wrapper.
    """
    delegator_class = wrapper[node.__class__]
    return delegator_class(node)
1247
if __name__ == "__main__":
    import sys

    # Compile each file named on the command line; compile() writes the
    # result to the same path with "c" appended.
    for path in sys.argv[1:]:
        compile(path)
1253