1""" @package antlr3.tree
2@brief ANTLR3 runtime package, treewizard module
3
4A utility module to create ASTs at runtime.
5See <http://www.antlr.org/wiki/display/~admin/2007/07/02/Exploring+Concept+of+TreeWizard> for an overview. Note that the API of the Python implementation is slightly different.
6
7"""
8
9# begin[licence]
10#
11# [The "BSD licence"]
12# Copyright (c) 2005-2008 Terence Parr
13# All rights reserved.
14#
15# Redistribution and use in source and binary forms, with or without
16# modification, are permitted provided that the following conditions
17# are met:
18# 1. Redistributions of source code must retain the above copyright
19#    notice, this list of conditions and the following disclaimer.
20# 2. Redistributions in binary form must reproduce the above copyright
21#    notice, this list of conditions and the following disclaimer in the
22#    documentation and/or other materials provided with the distribution.
23# 3. The name of the author may not be used to endorse or promote products
24#    derived from this software without specific prior written permission.
25#
26# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
27# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
28# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
29# IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
30# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
31# NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
32# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
33# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
34# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
35# THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36#
37# end[licence]
38
39from antlr3.constants import INVALID_TOKEN_TYPE
40from antlr3.tokens import CommonToken
41from antlr3.tree import CommonTree, CommonTreeAdaptor
42
43
def computeTokenTypes(tokenNames):
    """
    Build the name -> token type dict that inverts tokenNames.

    tokenNames is a sequence of token names indexed by token type; each
    name is mapped back to its index.  When a name appears more than once
    the highest index wins.  Passing None yields an empty dict.
    """

    typeMap = {}
    if tokenNames is not None:
        for ttype, tokenName in enumerate(tokenNames):
            typeMap[tokenName] = ttype

    return typeMap
54
55
## token types for pattern parser
## EOF is what TreePatternLexer.nextToken() returns at end of input or on error
EOF = -1
(BEGIN,    # '('
 END,      # ')'
 ID,       # a token name such as DECL, or nil
 ARG,      # a text argument written as [text]
 PERCENT,  # '%' introducing a label
 COLON,    # ':' terminating a label
 DOT,      # '.' wildcard
 ) = range(1, 8)
65
class TreePatternLexer(object):
    """
    Lexer for the tree pattern mini-language used by TreeWizard, e.g.
    "(ASSIGN %lhs:ID[x] .)".

    nextToken() returns one of the module-level token types (BEGIN, END,
    ID, ARG, PERCENT, COLON, DOT) and EOF at the end of the input.  For ID
    and ARG tokens the token text is left in self.sval.  When an
    unexpected character or an unterminated "[...]" argument is found,
    self.error is set and EOF is returned.
    """

    def __init__(self, pattern):
        ## The tree pattern to lex like "(A B C)"
        self.pattern = pattern

        ## Index into input string
        self.p = -1

        ## Current char, or EOF once the input is exhausted
        self.c = None

        ## How long is the pattern in char?
        self.n = len(pattern)

        ## Set when token type is ID or ARG
        self.sval = None

        ## Set when the lexer hit input it could not tokenize
        self.error = False

        # load the first character into self.c
        self.consume()


    __idStartChar = frozenset(
        'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_'
        )
    __idChar = __idStartChar | frozenset('0123456789')

    def nextToken(self):
        """
        Return the type of the next token.

        The text of ID and ARG tokens is stored in self.sval.  Returns EOF
        at the end of input, and also (with self.error set) on a lexical
        error.
        """

        self.sval = ""
        while self.c != EOF:
            if self.c in (' ', '\n', '\r', '\t'):
                self.consume()
                continue

            if self.c in self.__idStartChar:
                self.sval += self.c
                self.consume()
                while self.c in self.__idChar:
                    self.sval += self.c
                    self.consume()

                return ID

            if self.c == '(':
                self.consume()
                return BEGIN

            if self.c == ')':
                self.consume()
                return END

            if self.c == '%':
                self.consume()
                return PERCENT

            if self.c == ':':
                self.consume()
                return COLON

            if self.c == '.':
                self.consume()
                return DOT

            if self.c == '[': # grab [x] as a string, returning x
                self.consume()
                while self.c != ']':
                    if self.c == '\\':
                        # "\]" yields a literal "]"; any other escaped
                        # character keeps its backslash
                        self.consume()
                        if self.c != ']':
                            self.sval += '\\'

                    if self.c == EOF:
                        # input ended before the closing ']': report a lexer
                        # error instead of concatenating EOF (an int) to sval
                        self.error = True
                        return EOF

                    self.sval += self.c
                    self.consume()

                self.consume() # skip the closing ']'
                return ARG

            # any other character is invalid in a pattern
            self.consume()
            self.error = True
            return EOF

        return EOF


    def consume(self):
        """Advance to the next input character; self.c becomes EOF at the end."""

        self.p += 1
        if self.p >= self.n:
            self.c = EOF

        else:
            self.c = self.pattern[self.p]
161
162
class TreePatternParser(object):
    """
    Recursive-descent parser for tree patterns such as "(A %x:B[text] .)".

    Reads tokens from a TreePatternLexer and builds nodes through the
    given adaptor.  It resolves token names with wizard.getTokenType().
    Informal grammar:

        pattern : tree EOF
                | ID ARG? EOF        (a single node)
        tree    : BEGIN node (tree | node)* END
        node    : (PERCENT ID COLON)? (DOT | ID ARG?)

    The parse methods return None when the input is malformed.
    """

    def __init__(self, tokenizer, wizard, adaptor):
        self.tokenizer = tokenizer
        self.wizard = wizard
        self.adaptor = adaptor
        self.ttype = tokenizer.nextToken() # kickstart


    def pattern(self):
        """
        Parse a complete pattern and return the resulting tree or node.

        Returns None if the pattern is malformed, is followed by extra
        tokens, or the lexer reported an error.
        """

        if self.ttype == BEGIN:
            result = self.parseTree()

        elif self.ttype == ID:
            result = self.parseNode()

        else:
            return None

        if self.ttype != EOF:
            return None # extra junk on end

        if getattr(self.tokenizer, 'error', False):
            # the lexer stops with EOF on input it cannot tokenize; don't
            # mistake that for a clean end of the pattern
            return None

        return result


    def parseTree(self):
        """
        Parse "(root child ...)" and return root with its children attached.

        Returns None if the tree, or any of its subtrees or nodes, is
        malformed.
        """

        if self.ttype != BEGIN:
            return None

        self.ttype = self.tokenizer.nextToken()
        root = self.parseNode()
        if root is None:
            return None

        while self.ttype in (BEGIN, ID, PERCENT, DOT):
            if self.ttype == BEGIN:
                child = self.parseTree()

            else:
                child = self.parseNode()

            if child is None:
                # a malformed subtree invalidates the whole pattern instead
                # of handing None to adaptor.addChild and carrying on
                return None

            self.adaptor.addChild(root, child)

        if self.ttype != END:
            return None

        self.ttype = self.tokenizer.nextToken()
        return root


    def parseNode(self):
        """
        Parse one node: an optional "%label:" prefix followed by "." or
        "ID" / "ID[arg]".

        "nil" yields adaptor.nil().  Labels and the text-argument flag are
        recorded only on TreePattern nodes.  Returns None on a syntax error
        or an unknown token name.
        """

        # "%label:" prefix
        label = None

        if self.ttype == PERCENT:
            self.ttype = self.tokenizer.nextToken()
            if self.ttype != ID:
                return None

            label = self.tokenizer.sval
            self.ttype = self.tokenizer.nextToken()
            if self.ttype != COLON:
                return None

            self.ttype = self.tokenizer.nextToken() # move to ID following colon

        # Wildcard?
        if self.ttype == DOT:
            self.ttype = self.tokenizer.nextToken()
            # keyword arguments: the second positional parameter of
            # CommonToken is the channel, not the text
            wildcardPayload = CommonToken(type=0, text=".")
            node = WildcardTreePattern(wildcardPayload)
            if label is not None:
                node.label = label
            return node

        # "ID" or "ID[arg]"
        if self.ttype != ID:
            return None

        tokenName = self.tokenizer.sval
        self.ttype = self.tokenizer.nextToken()

        if tokenName == "nil":
            return self.adaptor.nil()

        text = tokenName
        # check for arg
        arg = None
        if self.ttype == ARG:
            arg = self.tokenizer.sval
            text = arg
            self.ttype = self.tokenizer.nextToken()

        # create node
        treeNodeType = self.wizard.getTokenType(tokenName)
        if treeNodeType == INVALID_TOKEN_TYPE:
            return None

        node = self.adaptor.createFromType(treeNodeType, text)
        if label is not None and isinstance(node, TreePattern):
            node.label = label

        if arg is not None and isinstance(node, TreePattern):
            node.hasTextArg = True

        return node
269
270
class TreePattern(CommonTree):
    """
    A CommonTree node produced while parsing a pattern.

    Besides its token payload it remembers the name from an optional
    "%label:" prefix and whether the pattern gave the node an explicit
    [text] argument.
    """

    def __init__(self, payload):
        CommonTree.__init__(self, payload)

        # name from a "%label:" prefix, or None when the node is unlabelled
        self.label = None
        # set to True by the pattern parser when the node was written ID[arg]
        self.hasTextArg = None


    def toString(self):
        text = CommonTree.toString(self)
        if self.label is None:
            return text

        return '%' + self.label + ':' + text
290
291
class WildcardTreePattern(TreePattern):
    """Pattern node for '.': matching skips its token type and text checks."""
294
295
class TreePatternTreeAdaptor(CommonTreeAdaptor):
    """
    Tree adaptor used while parsing patterns for scan-style operations.

    Every node it creates is a TreePattern, so pattern nodes can carry a
    label and a text-argument flag.
    """

    def createWithPayload(self, payload):
        node = TreePattern(payload)
        return node
301
302
class TreeWizard(object):
    """
    Build and navigate trees with this object.  Must know about the names
    of tokens so you have to pass in a map or array of token names (from which
    this class can build the map).  I.e., Token DECL means nothing unless the
    class can translate it to a token type.

    In order to create nodes and navigate, this class needs a TreeAdaptor.

    This class can build a token type -> node index for repeated use or for
    iterating over the various nodes with a particular type.

    This class works in conjunction with the TreeAdaptor rather than moving
    all this functionality into the adaptor.  An adaptor helps build and
    navigate trees using methods.  This class helps you do it with string
    patterns like "(A B C)".  You can create a tree from that pattern or
    match subtrees against it.
    """

    # Types that find() and visit() accept as a token type or a pattern.
    # `long` and `basestring` exist only on Python 2; fall back to the
    # Python 3 builtins so the dispatch works on both.
    try:
        _integerTypes = (int, long)
        _stringTypes = (basestring,)
    except NameError:
        _integerTypes = (int,)
        _stringTypes = (str,)


    def __init__(self, adaptor=None, tokenNames=None, typeMap=None):
        """
        Create a wizard.

        adaptor: TreeAdaptor used to build and inspect nodes; a new
            CommonTreeAdaptor is used when None.
        tokenNames: sequence of token names indexed by token type, from
            which the name -> type map is computed.
        typeMap: ready-made dict mapping token names to token types.

        Raises ValueError if both tokenNames and typeMap are given.
        """

        if adaptor is None:
            self.adaptor = CommonTreeAdaptor()

        else:
            self.adaptor = adaptor

        if typeMap is None:
            self.tokenNameToTypeMap = computeTokenTypes(tokenNames)

        else:
            if tokenNames is not None:
                raise ValueError("Can't have both tokenNames and typeMap")

            self.tokenNameToTypeMap = typeMap


    def getTokenType(self, tokenName):
        """
        Using the map of token names to token types, return the type.

        Returns INVALID_TOKEN_TYPE for an unknown name.
        """

        try:
            return self.tokenNameToTypeMap[tokenName]
        except KeyError:
            return INVALID_TOKEN_TYPE


    def _compilePattern(self, pattern, adaptor):
        """Parse pattern into a tree built with adaptor; None if malformed."""

        tokenizer = TreePatternLexer(pattern)
        parser = TreePatternParser(tokenizer, self, adaptor)
        return parser.pattern()


    def _compileRootedPattern(self, pattern):
        """
        Parse pattern into a TreePattern usable by find() and visit().

        Those searches visit nodes by the pattern root's token type, so
        None is returned for malformed patterns and for nil or wildcard
        roots.
        """

        tpattern = self._compilePattern(pattern, TreePatternTreeAdaptor())
        if (tpattern is None or tpattern.isNil()
            or isinstance(tpattern, WildcardTreePattern)):
            return None

        return tpattern


    def create(self, pattern):
        """
        Create a tree or node from the indicated tree pattern that closely
        follows ANTLR tree grammar tree element syntax:

        (root child1 ... child2).

        You can also just pass in a node: ID

        Any node can have a text argument: ID[foo]
        (notice there are no quotes around foo--it's clear it's a string).

        nil is a special name meaning "give me a nil node".  Useful for
        making lists: (nil A B C) is a list of A B C.

        Returns None if the pattern is malformed.
        """

        return self._compilePattern(pattern, self.adaptor)


    def index(self, tree):
        """Walk the entire tree and make a node name to nodes mapping.

        For now, use recursion but later nonrecursive version may be
        more efficient.  Returns a dict int -> list where the list is
        of your AST node type.  The int is the token type of the node.
        """

        m = {}
        self._index(tree, m)
        return m


    def _index(self, t, m):
        """Do the work for index: append t and its descendants to m."""

        if t is None:
            return

        m.setdefault(self.adaptor.getType(t), []).append(t)
        for i in range(self.adaptor.getChildCount(t)):
            self._index(self.adaptor.getChild(t, i), m)


    def find(self, tree, what):
        """Return a list of matching nodes.

        what may either be an integer specifying the token type to find or
        a string with a pattern that must be matched.  For a pattern that is
        malformed, nil-rooted or wildcard-rooted, None is returned instead
        of a list.

        Raises TypeError if what is neither an integer nor a string.
        """

        if isinstance(what, self._integerTypes):
            return self._findTokenType(tree, what)

        elif isinstance(what, self._stringTypes):
            return self._findPattern(tree, what)

        else:
            raise TypeError("'what' must be string or integer")


    def _findTokenType(self, t, ttype):
        """Return a List of tree nodes with token type ttype"""

        nodes = []

        def visitor(tree, parent, childIndex, labels):
            nodes.append(tree)

        self.visit(t, ttype, visitor)

        return nodes


    def _findPattern(self, t, pattern):
        """Return a List of subtrees matching pattern (None if invalid)."""

        # don't allow invalid patterns
        tpattern = self._compileRootedPattern(pattern)
        if tpattern is None:
            return None

        subtrees = []

        def visitor(tree, parent, childIndex, labels):
            if self._parse(tree, tpattern, None):
                subtrees.append(tree)

        self.visit(t, tpattern.getType(), visitor)

        return subtrees


    def visit(self, tree, what, visitor):
        """Visit every node in tree matching what, invoking the visitor.

        The visitor is called as visitor(node, parent, childIndex, labels).

        If what is a string, it is parsed as a pattern and only matching
        subtrees will be visited.
        The implementation uses the root node of the pattern in combination
        with visit(t, ttype, visitor) so nil-rooted patterns are not allowed.
        Patterns with wildcard roots are also not allowed.

        If what is an integer, it is used as a token type and visit will match
        all nodes of that type (this is faster than the pattern match).
        The labels arg of the visitor action method is never set (it's None)
        since using a token type rather than a pattern doesn't let us set a
        label.

        Raises TypeError if what is neither an integer nor a string.
        """

        if isinstance(what, self._integerTypes):
            self._visitType(tree, None, 0, what, visitor)

        elif isinstance(what, self._stringTypes):
            self._visitPattern(tree, what, visitor)

        else:
            raise TypeError("'what' must be string or integer")


    def _visitType(self, t, parent, childIndex, ttype, visitor):
        """Do the recursive work for visit (pre-order, parent first)."""

        if t is None:
            return

        if self.adaptor.getType(t) == ttype:
            visitor(t, parent, childIndex, None)

        for i in range(self.adaptor.getChildCount(t)):
            child = self.adaptor.getChild(t, i)
            self._visitType(child, t, i, ttype, visitor)


    def _visitPattern(self, tree, pattern, visitor):
        """
        For all subtrees that match the pattern, execute the visit action.

        Invalid, nil-rooted and wildcard-rooted patterns visit nothing.
        """

        # don't allow invalid patterns
        tpattern = self._compileRootedPattern(pattern)
        if tpattern is None:
            return

        def rootvisitor(tree, parent, childIndex, labels):
            # fresh label map per candidate node
            labels = {}
            if self._parse(tree, tpattern, labels):
                visitor(tree, parent, childIndex, labels)

        self.visit(tree, tpattern.getType(), rootvisitor)


    def parse(self, t, pattern, labels=None):
        """
        Given a pattern like (ASSIGN %lhs:ID %rhs:.) with optional labels
        on the various nodes and '.' (dot) as the node/subtree wildcard,
        return true if the pattern matches and fill the labels Map with
        the labels pointing at the appropriate nodes.  Return false if
        the pattern is malformed or the tree does not match.

        If a node specifies a text arg in pattern, then that must match
        for that node in t.

        Note that '.' matches any token type, but its child count is still
        compared, so a bare '.' only matches a node without children.
        """

        tpattern = self._compilePattern(pattern, TreePatternTreeAdaptor())
        return self._parse(t, tpattern, labels)


    def _parse(self, t1, tpattern, labels):
        """
        Do the work for parse. Check to see if the tpattern fits the
        structure and token types in t1.  Check text if the pattern has
        text arguments on nodes.  Fill labels map with pointers to nodes
        in tree matched against nodes in pattern with labels.

        A WildcardTreePattern skips the type and text checks, but the
        child counts are compared for every node, wildcards included.
        """

        # make sure both are non-null
        if t1 is None or tpattern is None:
            return False

        # check roots (wildcard matches anything)
        if not isinstance(tpattern, WildcardTreePattern):
            if self.adaptor.getType(t1) != tpattern.getType():
                return False

            # if pattern has text, check node text
            if (tpattern.hasTextArg
                and self.adaptor.getText(t1) != tpattern.getText()):
                return False

        if tpattern.label is not None and labels is not None:
            # map label in pattern to node in t1
            labels[tpattern.label] = t1

        # check children
        n1 = self.adaptor.getChildCount(t1)
        n2 = tpattern.getChildCount()
        if n1 != n2:
            return False

        for i in range(n1):
            child1 = self.adaptor.getChild(t1, i)
            child2 = tpattern.getChild(i)
            if not self._parse(child1, child2, labels):
                return False

        return True


    def equals(self, t1, t2, adaptor=None):
        """
        Compare t1 and t2; return true if token types/text, structure match
        exactly.
        The trees are examined in their entirety so that (A B) does not match
        (A B C) nor (A (B C)).

        adaptor defaults to the wizard's own adaptor.
        """

        if adaptor is None:
            adaptor = self.adaptor

        return self._equals(t1, t2, adaptor)


    def _equals(self, t1, t2, adaptor):
        """Recursive work for equals; a None on either side never matches."""

        # make sure both are non-null
        if t1 is None or t2 is None:
            return False

        # check roots
        if adaptor.getType(t1) != adaptor.getType(t2):
            return False

        if adaptor.getText(t1) != adaptor.getText(t2):
            return False

        # check children
        n1 = adaptor.getChildCount(t1)
        n2 = adaptor.getChildCount(t2)
        if n1 != n2:
            return False

        for i in range(n1):
            child1 = adaptor.getChild(t1, i)
            child2 = adaptor.getChild(t2, i)
            if not self._equals(child1, child2, adaptor):
                return False

        return True
620