1# -*- coding: utf-8 -*-
2
3import os
4import unittest
5from StringIO import StringIO
6
7from antlr3.tree import CommonTreeAdaptor, CommonTree, INVALID_TOKEN_TYPE
8from antlr3.treewizard import TreeWizard, computeTokenTypes, \
9     TreePatternLexer, EOF, ID, BEGIN, END, PERCENT, COLON, DOT, ARG, \
10     TreePatternParser, \
11     TreePattern, WildcardTreePattern, TreePatternTreeAdaptor
12
13
class TestComputeTokenTypes(unittest.TestCase):
    """Test case for the computeTokenTypes function.

    computeTokenTypes maps each token name to its index in the given list.
    """

    def testNone(self):
        """computeTokenTypes(None) -> {}"""

        typeMap = computeTokenTypes(None)
        self.assertTrue(isinstance(typeMap, dict))
        self.assertEqual(typeMap, {})


    def testList(self):
        """computeTokenTypes(['a', 'b']) -> { 'a': 0, 'b': 1 }"""

        typeMap = computeTokenTypes(['a', 'b'])
        self.assertTrue(isinstance(typeMap, dict))
        self.assertEqual(typeMap, { 'a': 0, 'b': 1 })
31
32
class TestTreePatternLexer(unittest.TestCase):
    """Test case for the TreePatternLexer class."""

    def _checkSingleToken(self, text, expectedType, expectedSval='',
                          expectedError=False):
        """Lex the first token of text and check its type, sval and error.

        Compares the value returned by nextToken() with expectedType, and
        the lexer's sval and error attributes with the expected values.
        """

        lexer = TreePatternLexer(text)
        # 'tokenType' rather than 'type' so the builtin is not shadowed.
        tokenType = lexer.nextToken()
        self.assertEqual(tokenType, expectedType)
        self.assertEqual(lexer.sval, expectedSval)
        self.assertEqual(lexer.error, expectedError)


    def testBegin(self):
        """TreePatternLexer(): '('"""

        self._checkSingleToken('(', BEGIN)


    def testEnd(self):
        """TreePatternLexer(): ')'"""

        self._checkSingleToken(')', END)


    def testPercent(self):
        """TreePatternLexer(): '%'"""

        self._checkSingleToken('%', PERCENT)


    def testDot(self):
        """TreePatternLexer(): '.'"""

        self._checkSingleToken('.', DOT)


    def testColon(self):
        """TreePatternLexer(): ':'"""

        self._checkSingleToken(':', COLON)


    def testEOF(self):
        """TreePatternLexer(): EOF"""

        # Whitespace only: the first token is already EOF.
        self._checkSingleToken('  \n \r \t ', EOF)


    def testID(self):
        """TreePatternLexer(): ID"""

        self._checkSingleToken('_foo12_bar', ID, '_foo12_bar')


    def testARG(self):
        """TreePatternLexer(): ARG"""

        # '\]' inside brackets yields a literal ']'; other escapes are kept.
        self._checkSingleToken('[ \\]bla\\n]', ARG, ' ]bla\\n')


    def testError(self):
        """TreePatternLexer(): error"""

        # A leading digit is not valid: EOF is returned and error is set.
        self._checkSingleToken('1', EOF, '', True)
124
125
class TestTreePatternParser(unittest.TestCase):
    """Test case for the TreePatternParser class."""

    def setUp(self):
        """Set up the test fixture.

        We need a tree adaptor (a CommonTreeAdaptor), a constant list of
        token names, where a name's index is its token type ('A' is 5,
        'B' is 6, ..., 'ID' is 10, 'VAR' is 11), and a TreeWizard built
        from both.

        """

        self.adaptor = CommonTreeAdaptor()
        self.tokens = [
            "", "", "", "", "", "A", "B", "C", "D", "E", "ID", "VAR"
            ]
        self.wizard = TreeWizard(self.adaptor, tokenNames=self.tokens)


    def _parse(self, pattern, adaptor=None):
        """Run TreePatternParser over pattern and return the parsed tree.

        The fixture's CommonTreeAdaptor is used unless adaptor is given.
        """

        if adaptor is None:
            adaptor = self.adaptor
        lexer = TreePatternLexer(pattern)
        parser = TreePatternParser(lexer, self.wizard, adaptor)
        return parser.pattern()


    def testSingleNode(self):
        """TreePatternParser: 'ID'"""
        tree = self._parse('ID')
        self.assertTrue(isinstance(tree, CommonTree))
        self.assertEqual(tree.getType(), 10)
        self.assertEqual(tree.getText(), 'ID')


    def testSingleNodeWithArg(self):
        """TreePatternParser: 'ID[foo]'"""
        tree = self._parse('ID[foo]')
        self.assertTrue(isinstance(tree, CommonTree))
        self.assertEqual(tree.getType(), 10)
        # The bracketed argument replaces the node text.
        self.assertEqual(tree.getText(), 'foo')


    def testSingleLevelTree(self):
        """TreePatternParser: '(A B)'"""
        tree = self._parse('(A B)')
        self.assertTrue(isinstance(tree, CommonTree))
        self.assertEqual(tree.getType(), 5)
        self.assertEqual(tree.getText(), 'A')
        self.assertEqual(tree.getChildCount(), 1)
        self.assertEqual(tree.getChild(0).getType(), 6)
        self.assertEqual(tree.getChild(0).getText(), 'B')


    def testNil(self):
        """TreePatternParser: 'nil'"""
        tree = self._parse('nil')
        self.assertTrue(isinstance(tree, CommonTree))
        self.assertEqual(tree.getType(), 0)
        self.assertEqual(tree.getText(), None)


    def testWildcard(self):
        """TreePatternParser: '(.)'"""
        tree = self._parse('(.)')
        self.assertTrue(isinstance(tree, WildcardTreePattern))


    def testLabel(self):
        """TreePatternParser: '(%a:A)'"""
        # Labels are only recorded on TreePattern nodes, so use that adaptor.
        tree = self._parse('(%a:A)', TreePatternTreeAdaptor())
        self.assertTrue(isinstance(tree, TreePattern))
        self.assertEqual(tree.label, 'a')


    def testError1(self):
        """TreePatternParser: ')'"""
        tree = self._parse(')')
        self.assertTrue(tree is None)


    def testError2(self):
        """TreePatternParser: '()'"""
        tree = self._parse('()')
        self.assertTrue(tree is None)


    def testError3(self):
        """TreePatternParser: '(A ])'"""
        tree = self._parse('(A ])')
        self.assertTrue(tree is None)
226
227
class TestTreeWizard(unittest.TestCase):
    """Test case for the TreeWizard class."""

    def setUp(self):
        """Set up the test fixture.

        We need a tree adaptor (a CommonTreeAdaptor) and a constant list
        of token names, where a name's index is its token type ('A' is 5,
        'B' is 6, ..., 'ID' is 10, 'VAR' is 11).

        """

        self.adaptor = CommonTreeAdaptor()
        self.tokens = [
            "", "", "", "", "", "A", "B", "C", "D", "E", "ID", "VAR"
            ]


    def testInit(self):
        """TreeWizard.__init__()"""

        wiz = TreeWizard(
            self.adaptor,
            tokenNames=['a', 'b']
            )

        self.assertTrue(wiz.adaptor is self.adaptor)
        self.assertEqual(
            wiz.tokenNameToTypeMap,
            { 'a': 0, 'b': 1 }
            )


    def testGetTokenType(self):
        """TreeWizard.getTokenType()"""

        wiz = TreeWizard(
            self.adaptor,
            tokenNames=self.tokens
            )

        self.assertEqual(
            wiz.getTokenType('A'),
            5
            )

        self.assertEqual(
            wiz.getTokenType('VAR'),
            11
            )

        self.assertEqual(
            wiz.getTokenType('invalid'),
            INVALID_TOKEN_TYPE
            )

    def testSingleNode(self):
        """TreeWizard.create(): 'ID'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("ID")
        found = t.toStringTree()
        expecting = "ID"
        self.assertEqual(expecting, found)


    def testSingleNodeWithArg(self):
        """TreeWizard.create(): 'ID[foo]'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("ID[foo]")
        found = t.toStringTree()
        expecting = "foo"
        self.assertEqual(expecting, found)


    def testSingleNodeTree(self):
        """TreeWizard.create(): '(A)'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A)")
        found = t.toStringTree()
        expecting = "A"
        self.assertEqual(expecting, found)


    def testSingleLevelTree(self):
        """TreeWizard.create(): '(A B C D)'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C D)")
        found = t.toStringTree()
        expecting = "(A B C D)"
        self.assertEqual(expecting, found)


    def testListTree(self):
        """TreeWizard.create(): '(nil A B C)'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(nil A B C)")
        found = t.toStringTree()
        expecting = "A B C"
        self.assertEqual(expecting, found)


    def testInvalidListTree(self):
        """TreeWizard.create(): 'A B C' -> None"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("A B C")
        self.assertTrue(t is None)


    def testDoubleLevelTree(self):
        """TreeWizard.create(): '(A (B C) (B D) E)'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A (B C) (B D) E)")
        found = t.toStringTree()
        expecting = "(A (B C) (B D) E)"
        self.assertEqual(expecting, found)


    def __simplifyIndexMap(self, indexMap):
        """Return indexMap with each node list replaced by its str() forms."""
        return dict( # stringify nodes for easy comparing
            (ttype, [str(node) for node in nodes])
            for ttype, nodes in indexMap.items()
            )

    def testSingleNodeIndex(self):
        """TreeWizard.index(): single node"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("ID")
        indexMap = wiz.index(tree)
        found = self.__simplifyIndexMap(indexMap)
        expecting = { 10: ["ID"] }
        self.assertEqual(expecting, found)


    def testNoRepeatsIndex(self):
        """TreeWizard.index(): no repeated token types"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B C D)")
        indexMap = wiz.index(tree)
        found = self.__simplifyIndexMap(indexMap)
        expecting = { 8:['D'], 6:['B'], 7:['C'], 5:['A'] }
        self.assertEqual(expecting, found)


    def testRepeatsIndex(self):
        """TreeWizard.index(): repeated token types"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")
        indexMap = wiz.index(tree)
        found = self.__simplifyIndexMap(indexMap)
        expecting = { 8: ['D', 'D'], 6: ['B', 'B', 'B'], 7: ['C'], 5: ['A', 'A'] }
        self.assertEqual(expecting, found)


    def testNoRepeatsVisit(self):
        """TreeWizard.visit(): token type B, single match"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B C D)")

        elements = []
        def visitor(node, parent, childIndex, labels):
            elements.append(str(node))

        wiz.visit(tree, wiz.getTokenType("B"), visitor)

        expecting = ['B']
        self.assertEqual(expecting, elements)


    def testNoRepeatsVisit2(self):
        """TreeWizard.visit(): token type C, single nested match"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")

        elements = []
        def visitor(node, parent, childIndex, labels):
            elements.append(str(node))

        wiz.visit(tree, wiz.getTokenType("C"), visitor)

        expecting = ['C']
        self.assertEqual(expecting, elements)


    def testRepeatsVisit(self):
        """TreeWizard.visit(): token type B, several matches"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")

        elements = []
        def visitor(node, parent, childIndex, labels):
            elements.append(str(node))

        wiz.visit(tree, wiz.getTokenType("B"), visitor)

        expecting = ['B', 'B', 'B']
        self.assertEqual(expecting, elements)


    def testRepeatsVisit2(self):
        """TreeWizard.visit(): token type A, including the root"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")

        elements = []
        def visitor(node, parent, childIndex, labels):
            elements.append(str(node))

        wiz.visit(tree, wiz.getTokenType("A"), visitor)

        expecting = ['A', 'A']
        self.assertEqual(expecting, elements)


    def testRepeatsVisitWithContext(self):
        """TreeWizard.visit(): visitor receives parent and child index"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")

        elements = []
        def visitor(node, parent, childIndex, labels):
            elements.append('%s@%s[%d]' % (node, parent, childIndex))

        wiz.visit(tree, wiz.getTokenType("B"), visitor)

        expecting = ['B@A[0]', 'B@A[1]', 'B@A[2]']
        self.assertEqual(expecting, elements)


    def testRepeatsVisitWithNullParentAndContext(self):
        """TreeWizard.visit(): the root is visited with parent None"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")

        elements = []
        def visitor(node, parent, childIndex, labels):
            # Render a missing (None) parent as 'nil'.
            elements.append(
                '%s@%s[%d]'
                % (node, ['nil', parent][parent is not None], childIndex)
                )

        wiz.visit(tree, wiz.getTokenType("A"), visitor)

        expecting = ['A@nil[0]', 'A@A[1]']
        self.assertEqual(expecting, elements)


    def testVisitPattern(self):
        """TreeWizard.visit(): pattern '(A B)'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B C (A B) D)")

        elements = []
        def visitor(node, parent, childIndex, labels):
            elements.append(
                str(node)
                )

        wiz.visit(tree, '(A B)', visitor)

        expecting = ['A'] # shouldn't match overall root, just (A B)
        self.assertEqual(expecting, elements)


    def testVisitPatternMultiple(self):
        """TreeWizard.visit(): pattern '(A B)', several matches"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B C (A B) (D (A B)))")

        elements = []
        def visitor(node, parent, childIndex, labels):
            elements.append(
                '%s@%s[%d]'
                % (node, ['nil', parent][parent is not None], childIndex)
                )

        wiz.visit(tree, '(A B)', visitor)

        expecting = ['A@A[2]', 'A@D[0]']
        self.assertEqual(expecting, elements)


    def testVisitPatternMultipleWithLabels(self):
        """TreeWizard.visit(): labelled pattern fills the labels dict"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B C (A[foo] B[bar]) (D (A[big] B[dog])))")

        elements = []
        def visitor(node, parent, childIndex, labels):
            elements.append(
                '%s@%s[%d]%s&%s'
                % (node,
                   ['nil', parent][parent is not None],
                   childIndex,
                   labels['a'],
                   labels['b'],
                   )
                )

        wiz.visit(tree, '(%a:A %b:B)', visitor)

        expecting = ['foo@A[2]foo&bar', 'big@D[0]big&dog']
        self.assertEqual(expecting, elements)


    def testParse(self):
        """TreeWizard.parse(): '(A B C)'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C)")
        valid = wiz.parse(t, "(A B C)")
        self.assertTrue(valid)


    def testParseSingleNode(self):
        """TreeWizard.parse(): 'A'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("A")
        valid = wiz.parse(t, "A")
        self.assertTrue(valid)


    def testParseSingleNodeFails(self):
        """TreeWizard.parse(): 'A' against pattern 'B' fails"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("A")
        valid = wiz.parse(t, "B")
        self.assertFalse(valid)


    def testParseFlatTree(self):
        """TreeWizard.parse(): '(nil A B C)'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(nil A B C)")
        valid = wiz.parse(t, "(nil A B C)")
        self.assertTrue(valid)


    def testParseFlatTreeFails(self):
        """TreeWizard.parse(): flat tree with too few pattern children fails"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(nil A B C)")
        valid = wiz.parse(t, "(nil A B)")
        self.assertFalse(valid)


    def testParseFlatTreeFails2(self):
        """TreeWizard.parse(): flat tree with a mismatched child fails"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(nil A B C)")
        valid = wiz.parse(t, "(nil A B A)")
        self.assertFalse(valid)


    def testWildcard(self):
        """TreeWizard.parse(): wildcards '(A . .)'"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C)")
        valid = wiz.parse(t, "(A . .)")
        self.assertTrue(valid)


    def testParseWithText(self):
        """TreeWizard.parse(): text args are only checked where given"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B[foo] C[bar])")
        # C pattern has no text arg so despite [bar] in t, no need
        # to match text--check structure only.
        valid = wiz.parse(t, "(A B[foo] C)")
        self.assertTrue(valid)


    def testParseWithText2(self):
        """TreeWizard.parse(): a failed match leaves the tree unchanged"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B[T__32] (C (D E[a])))")
        # The pattern's B[foo] text arg does not match B's text 'T__32',
        # so the parse must fail; the tree itself must come out unchanged.
        valid = wiz.parse(t, "(A B[foo] C)")
        self.assertFalse(valid)
        self.assertEqual("(A T__32 (C (D a)))", t.toStringTree())


    def testParseWithTextFails(self):
        """TreeWizard.parse(): mismatched text arg fails"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C)")
        valid = wiz.parse(t, "(A[foo] B C)")
        self.assertFalse(valid) # fails


    def testParseLabels(self):
        """TreeWizard.parse(): labels are filled in"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C)")
        labels = {}
        valid = wiz.parse(t, "(%a:A %b:B %c:C)", labels)
        self.assertTrue(valid)
        self.assertEqual("A", str(labels["a"]))
        self.assertEqual("B", str(labels["b"]))
        self.assertEqual("C", str(labels["c"]))


    def testParseWithWildcardLabels(self):
        """TreeWizard.parse(): labels on wildcards are filled in"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C)")
        labels = {}
        valid = wiz.parse(t, "(A %b:. %c:.)", labels)
        self.assertTrue(valid)
        self.assertEqual("B", str(labels["b"]))
        self.assertEqual("C", str(labels["c"]))


    def testParseLabelsAndTestText(self):
        """TreeWizard.parse(): labels combined with a text arg"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B[foo] C)")
        labels = {}
        valid = wiz.parse(t, "(%a:A %b:B[foo] %c:C)", labels)
        self.assertTrue(valid)
        self.assertEqual("A", str(labels["a"]))
        self.assertEqual("foo", str(labels["b"]))
        self.assertEqual("C", str(labels["c"]))


    def testParseLabelsInNestedTree(self):
        """TreeWizard.parse(): labels in nested subtrees are filled in"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A (B C) (D E))")
        labels = {}
        valid = wiz.parse(t, "(%a:A (%b:B %c:C) (%d:D %e:E) )", labels)
        self.assertTrue(valid)
        self.assertEqual("A", str(labels["a"]))
        self.assertEqual("B", str(labels["b"]))
        self.assertEqual("C", str(labels["c"]))
        self.assertEqual("D", str(labels["d"]))
        self.assertEqual("E", str(labels["e"]))


    def testEquals(self):
        """TreeWizard.equals(): identical trees"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t1 = wiz.create("(A B C)")
        t2 = wiz.create("(A B C)")
        same = wiz.equals(t1, t2)
        self.assertTrue(same)


    def testEqualsWithText(self):
        """TreeWizard.equals(): identical trees with node text"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t1 = wiz.create("(A B[foo] C)")
        t2 = wiz.create("(A B[foo] C)")
        same = wiz.equals(t1, t2)
        self.assertTrue(same)


    def testEqualsWithMismatchedText(self):
        """TreeWizard.equals(): differing node text"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t1 = wiz.create("(A B[foo] C)")
        t2 = wiz.create("(A B C)")
        same = wiz.equals(t1, t2)
        self.assertFalse(same)


    def testEqualsWithMismatchedList(self):
        """TreeWizard.equals(): differing child token type"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t1 = wiz.create("(A B C)")
        t2 = wiz.create("(A B A)")
        same = wiz.equals(t1, t2)
        self.assertFalse(same)


    def testEqualsWithMismatchedListLength(self):
        """TreeWizard.equals(): differing child count"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t1 = wiz.create("(A B C)")
        t2 = wiz.create("(A B)")
        same = wiz.equals(t1, t2)
        self.assertFalse(same)


    def testFindPattern(self):
        """TreeWizard.find(): by pattern"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C (A[foo] B[bar]) (D (A[big] B[dog])))")
        subtrees = wiz.find(t, "(A B)")
        found = [str(node) for node in subtrees]
        expecting = ['foo', 'big']
        self.assertEqual(expecting, found)


    def testFindTokenType(self):
        """TreeWizard.find(): by token type"""
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C (A[foo] B[bar]) (D (A[big] B[dog])))")
        subtrees = wiz.find(t, wiz.getTokenType('A'))
        found = [str(node) for node in subtrees]
        expecting = ['A', 'foo', 'big']
        self.assertEqual(expecting, found)
687
688
689
if __name__ == "__main__":
    # Run all test cases of this module, reporting each test by name.
    verboseRunner = unittest.TextTestRunner(verbosity=2)
    unittest.main(testRunner=verboseRunner)
692