1""" Test suite for the code in fixer_util """
2
3# Testing imports
4from . import support
5
6# Python imports
7import os.path
8
9# Local imports
10from lib2to3.pytree import Node, Leaf
11from lib2to3 import fixer_util
12from lib2to3.fixer_util import Attr, Name, Call, Comma
13from lib2to3.pgen2 import token
14
def parse(code, strip_levels=0):
    """Parse *code* and return the tree, descended *strip_levels* levels.

    The tree comes from ``support.parse_string``.  Its topmost node is
    file_input and the next one down is a *_stmt node; callers that do not
    care about those wrappers pass a positive *strip_levels* to step past
    them (each level follows the first child).  The returned node has its
    ``parent`` reset to None.
    """
    node = support.parse_string(code)
    for _ in range(strip_levels):
        node = node.children[0]
    node.parent = None
    return node
23
class MacroTestCase(support.TestCase):
    """Base class adding a string-comparison assertion for nodes."""

    def assertStr(self, node, string):
        """Assert that ``str(node)`` equals *string*.

        A tuple or list is first wrapped in a ``simple_stmt`` Node so that
        it is rendered as a single unit.
        """
        if not isinstance(node, (tuple, list)):
            self.assertEqual(str(node), string)
            return
        wrapped = Node(fixer_util.syms.simple_stmt, node)
        self.assertEqual(str(wrapped), string)
29
30
class Test_is_tuple(support.TestCase):
    """Checks for ``fixer_util.is_tuple``."""

    def is_tuple(self, string):
        """Return ``fixer_util.is_tuple`` for the node parsed from *string*.

        The parsed tree is descended two levels before being handed over.
        """
        expr = parse(string, strip_levels=2)
        return fixer_util.is_tuple(expr)

    def test_valid(self):
        # Sources that is_tuple() must report as tuples.
        accepted = (
            "(a, b)",
            "(a, (b, c))",
            "((a, (b, c)),)",
            "(a,)",
            "()",
        )
        for source in accepted:
            self.assertTrue(self.is_tuple(source))

    def test_invalid(self):
        # Sources that is_tuple() must not report as tuples.
        rejected = (
            "(a)",
            "('foo') % (b, c)",
        )
        for source in rejected:
            self.assertFalse(self.is_tuple(source))
45
46
class Test_is_list(support.TestCase):
    """Checks for ``fixer_util.is_list``."""

    def is_list(self, string):
        """Return ``fixer_util.is_list`` for the node parsed from *string*.

        The parsed tree is descended two levels before being handed over.
        """
        expr = parse(string, strip_levels=2)
        return fixer_util.is_list(expr)

    def test_valid(self):
        # Sources that is_list() must report as lists.
        accepted = (
            "[]",
            "[a]",
            "[a, b]",
            "[a, [b, c]]",
            "[[a, [b, c]],]",
        )
        for source in accepted:
            self.assertTrue(self.is_list(source))

    def test_invalid(self):
        # A sum of two lists is not itself a list display.
        for source in ("[]+[]",):
            self.assertFalse(self.is_list(source))
60
61
class Test_Attr(MacroTestCase):
    """Checks for the ``fixer_util.Attr`` macro."""

    def test(self):
        # Attr() must render as "<object>.<attribute>" both for a plain name
        # and for a parsed call expression used as the object.
        call = parse("foo()", strip_levels=2)

        on_name = Attr(Name("a"), Name("b"))
        self.assertStr(on_name, "a.b")

        on_call = Attr(call, Name("b"))
        self.assertStr(on_call, "foo().b")

    def test_returns(self):
        # The macro hands back exactly a list (checked via type(), so a
        # subclass would not pass).
        result_type = type(Attr(Name("a"), Name("b")))
        self.assertEqual(result_type, list)
72
73
class Test_Name(MacroTestCase):
    """Checks for the ``fixer_util.Name`` macro."""

    def test(self):
        # The given text is rendered verbatim, whatever it contains.
        plain = Name("a")
        self.assertStr(plain, "a")

        dotted = Name("foo.foo().bar")
        self.assertStr(dotted, "foo.foo().bar")

        # A prefix is rendered in front of the name itself.
        prefixed = Name("a", prefix="b")
        self.assertStr(prefixed, "ba")
79
80
class Test_Call(MacroTestCase):
    """Checks for the ``fixer_util.Call`` macro."""

    def _Call(self, name, args=None, prefix=None):
        """Build ``Call(Name(name), children, prefix)`` for the test below.

        When *args* is a list, its items become the call's children with a
        ``Comma()`` between consecutive items and no trailing comma.  Any
        other value, including the default None, gives an empty child list.
        An empty list is accepted as well and also gives an empty child
        list.
        """
        children = []
        if isinstance(args, list):
            for position, arg in enumerate(args):
                # Emit the separator before every argument except the first;
                # that way there is never a trailing comma to strip, and an
                # empty list needs no special handling.
                if position:
                    children.append(Comma())
                children.append(arg)
        return Call(Name(name), children, prefix)

    def test(self):
        # kids[0] is only a placeholder so the real argument lists start at
        # index 1; it is never passed to _Call().
        kids = [None,
                [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 2),
                 Leaf(token.NUMBER, 3)],
                [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 3),
                 Leaf(token.NUMBER, 2), Leaf(token.NUMBER, 4)],
                [Leaf(token.STRING, "b"), Leaf(token.STRING, "j", prefix=" ")]
                ]
        # No arguments at all.
        self.assertStr(self._Call("A"), "A()")
        # Arguments are joined by bare commas (the leaves carry no prefix).
        self.assertStr(self._Call("b", kids[1]), "b(1,2,3)")
        self.assertStr(self._Call("a.b().c", kids[2]), "a.b().c(1,3,2,4)")
        # Prefixes on the call and on an argument leaf are both rendered.
        self.assertStr(self._Call("d", kids[3], prefix=" "), " d(b, j)")
104
105
class Test_does_tree_import(support.TestCase):
    """Checks for ``fixer_util.does_tree_import``."""

    def _find_bind_rec(self, name, node):
        """Search *node* and its descendants for a binding of *name*.

        Used to find the starting point for these tests.  Returns the first
        truthy result of ``fixer_util.find_binding`` found depth-first, or
        None when there is none.
        """
        found = fixer_util.find_binding(name, node)
        if found:
            return found
        for child in node.children:
            found = self._find_bind_rec(name, child)
            if found:
                return found
        return None

    def does_tree_import(self, package, name, string):
        """Run ``does_tree_import`` starting from the binding of ``start``.

        *string* is parsed, the binding of the name ``start`` is located in
        the resulting tree, and that node is what gets passed on.
        """
        tree = parse(string)
        # Find the binding of start -- that's what we'll go from
        start = self._find_bind_rec('start', tree)
        return fixer_util.does_tree_import(package, name, start)

    def try_with(self, string):
        """Check a fixed table of imports placed before and after *string*.

        Each table row is ``(package, name, import statement)``.
        """
        failing_tests = (
            ("a", "a", "from a import b"),
            ("a.d", "a", "from a.d import b"),
            ("d.a", "a", "from d.a import b"),
            (None, "a", "import b"),
            (None, "a", "import b, c, d"),
        )
        passing_tests = (
            ("a", "a", "from a import a"),
            ("x", "a", "from x import a"),
            ("x", "a", "from x import b, c, a, d"),
            ("x.b", "a", "from x.b import a"),
            ("x.b", "a", "from x.b import b, c, a, d"),
            (None, "a", "import a"),
            (None, "a", "import b, c, a, d"),
        )
        # The failing table is run first, then the passing one; within a
        # row the import is tried ahead of the code, then behind it.
        plan = (
            (failing_tests, self.assertFalse),
            (passing_tests, self.assertTrue),
        )
        for table, check in plan:
            for package, name, import_ in table:
                check(self.does_tree_import(package, name,
                                            import_ + "\n" + string))
                check(self.does_tree_import(package, name,
                                            string + "\n" + import_))

    def test_in_function(self):
        # The 'start' binding lives inside a function body.
        self.try_with("def foo():\n\tbar.baz()\n\tstart=3")
149
class Test_find_binding(support.TestCase):
    """Checks for ``fixer_util.find_binding``."""

    def find_binding(self, name, string, package=None):
        """Return ``fixer_util.find_binding`` for *name* in parsed *string*."""
        tree = parse(string)
        return fixer_util.find_binding(name, tree, package)

    def _expect_bound(self, name, sources):
        """Assert that *name* has a binding in each of *sources*, in order."""
        for source in sources:
            self.assertTrue(self.find_binding(name, source))

    def _expect_unbound(self, name, sources):
        """Assert that *name* has no binding in any of *sources*, in order."""
        for source in sources:
            self.assertFalse(self.find_binding(name, source))

    def _expect_bound_from(self, name, cases):
        """Like _expect_bound, for ``(source, package)`` pairs."""
        for source, package in cases:
            self.assertTrue(self.find_binding(name, source, package))

    def _expect_unbound_from(self, name, cases):
        """Like _expect_unbound, for ``(source, package)`` pairs."""
        for source, package in cases:
            self.assertFalse(self.find_binding(name, source, package))

    def test_simple_assignment(self):
        self._expect_bound("a", (
            "a = b",
            "a = [b, c, d]",
            "a = foo()",
            "a = foo().foo.foo[6][foo]",
        ))
        self._expect_unbound("a", (
            "foo = a",
            "foo = (a, b, c)",
        ))

    def test_tuple_assignment(self):
        self._expect_bound("a", (
            "(a,) = b",
            "(a, b, c) = [b, c, d]",
            "(c, (d, a), b) = foo()",
            "(a, b) = foo().foo[6][foo]",
        ))
        self._expect_unbound("a", (
            "(foo, b) = (b, a)",
            "(foo, (b, c)) = (a, b, c)",
        ))

    def test_list_assignment(self):
        self._expect_bound("a", (
            "[a] = b",
            "[a, b, c] = [b, c, d]",
            "[c, [d, a], b] = foo()",
            "[a, b] = foo().foo[a][foo]",
        ))
        self._expect_unbound("a", (
            "[foo, b] = (b, a)",
            "[foo, [b, c]] = (a, b, c)",
        ))

    def test_invalid_assignments(self):
        self._expect_unbound("a", (
            "foo.a = 5",
            "foo[a] = 5",
            "foo(a) = 5",
            "foo(a, b) = 5",
        ))

    def test_simple_import(self):
        self._expect_bound("a", (
            "import a",
            "import b, c, a, d",
        ))
        self._expect_unbound("a", (
            "import b",
            "import b, c, d",
        ))

    def test_from_import(self):
        self._expect_bound("a", (
            "from x import a",
            "from a import a",
            "from x import b, c, a, d",
            "from x.b import a",
            "from x.b import b, c, a, d",
        ))
        self._expect_unbound("a", (
            "from a import b",
            "from a.d import b",
            "from d.a import b",
        ))

    def test_import_as(self):
        self._expect_bound("a", (
            "import b as a",
            "import b as a, c, a as f, d",
        ))
        self._expect_unbound("a", (
            "import a as f",
            "import b, c as f, d as e",
        ))

    def test_from_import_as(self):
        self._expect_bound("a", (
            "from x import b as a",
            "from x import g as a, d as b",
            "from x.b import t as a",
            "from x.b import g as a, d",
        ))
        self._expect_unbound("a", (
            "from a import b as t",
            "from a.d import b as t",
            "from d.a import b as t",
        ))

    def test_simple_import_with_package(self):
        self._expect_bound("b", (
            "import b",
            "import b, c, d",
        ))
        self._expect_unbound_from("b", (
            ("import b", "b"),
            ("import b, c, d", "c"),
        ))

    def test_from_import_with_package(self):
        self._expect_bound_from("a", (
            ("from x import a", "x"),
            ("from a import a", "a"),
            ("from x import *", "x"),
            ("from x import b, c, a, d", "x"),
            ("from x.b import a", "x.b"),
            ("from x.b import *", "x.b"),
            ("from x.b import b, c, a, d", "x.b"),
        ))
        self._expect_unbound_from("a", (
            ("from a import b", "a"),
            ("from a.d import b", "a.d"),
            ("from d.a import b", "a.d"),
            ("from x.y import *", "a.b"),
        ))

    def test_import_as_with_package(self):
        self._expect_unbound_from("a", (
            ("import b.c as a", "b.c"),
            ("import a as f", "f"),
            ("import a as f", "a"),
        ))

    def test_from_import_as_with_package(self):
        # Because it would take a lot of special-case code in the fixers
        # to deal with from foo import bar as baz, we'll simply always
        # fail if there is an "from ... import ... as ..."
        self._expect_unbound_from("a", (
            ("from x import b as a", "x"),
            ("from x import g as a, d as b", "x"),
            ("from x.b import t as a", "x.b"),
            ("from x.b import g as a, d", "x.b"),
            ("from a import b as t", "a"),
            ("from a import b as t", "b"),
            ("from a import b as t", "t"),
        ))

    def test_function_def(self):
        self._expect_bound("a", (
            "def a(): pass",
            "def a(b, c, d): pass",
            "def a(): b = 7",
        ))
        self._expect_unbound("a", (
            "def d(b, (c, a), e): pass",
            "def d(a=7): pass",
            "def d(a): pass",
            "def d(): a = 7",
            """
            def d():
                def a():
                    pass""",
        ))

    def test_class_def(self):
        self._expect_bound("a", (
            "class a: pass",
            "class a(): pass",
            "class a(b): pass",
            "class a(b, c=8): pass",
        ))
        self._expect_unbound("a", (
            "class d: pass",
            "class d(a): pass",
            "class d(b, a=7): pass",
            "class d(b, *a): pass",
            "class d(b, **a): pass",
            "class d: a = 7",
            """
            class d():
                class a():
                    pass""",
        ))

    def test_for(self):
        self._expect_bound("a", (
            "for a in r: pass",
            "for a, b in r: pass",
            "for (a, b) in r: pass",
            "for c, (a,) in r: pass",
            "for c, (a, b) in r: pass",
            "for c in r: a = c",
        ))
        self._expect_unbound("a", (
            "for c in a: pass",
        ))

    def test_for_nested(self):
        self._expect_bound("a", (
            """
            for b in r:
                for a in b:
                    pass""",
            """
            for b in r:
                for a, c in b:
                    pass""",
            """
            for b in r:
                for (a, c) in b:
                    pass""",
            """
            for b in r:
                for (a,) in b:
                    pass""",
            """
            for b in r:
                for c, (a, d) in b:
                    pass""",
            """
            for b in r:
                for c in b:
                    a = 7""",
        ))
        self._expect_unbound("a", (
            """
            for b in r:
                for c in b:
                    d = a""",
            """
            for b in r:
                for c in a:
                    d = 7""",
        ))

    def test_if(self):
        self._expect_bound("a", (
            "if b in r: a = c",
        ))
        self._expect_unbound("a", (
            "if a in r: d = e",
        ))

    def test_if_nested(self):
        self._expect_bound("a", (
            """
            if b in r:
                if c in d:
                    a = c""",
        ))
        self._expect_unbound("a", (
            """
            if b in r:
                if c in d:
                    c = a""",
        ))

    def test_while(self):
        self._expect_bound("a", (
            "while b in r: a = c",
        ))
        self._expect_unbound("a", (
            "while a in r: d = e",
        ))

    def test_while_nested(self):
        self._expect_bound("a", (
            """
            while b in r:
                while c in d:
                    a = c""",
        ))
        self._expect_unbound("a", (
            """
            while b in r:
                while c in d:
                    c = a""",
        ))

    def test_try_except(self):
        self._expect_bound("a", (
            """
            try:
                a = 6
            except:
                b = 8""",
            """
            try:
                b = 8
            except:
                a = 6""",
            """
            try:
                b = 8
            except KeyError:
                pass
            except:
                a = 6""",
        ))
        self._expect_unbound("a", (
            """
            try:
                b = 8
            except:
                b = 6""",
        ))

    def test_try_except_nested(self):
        self._expect_bound("a", (
            """
            try:
                try:
                    a = 6
                except:
                    pass
            except:
                b = 8""",
            """
            try:
                b = 8
            except:
                try:
                    a = 6
                except:
                    pass""",
            """
            try:
                b = 8
            except:
                try:
                    pass
                except:
                    a = 6""",
            """
            try:
                try:
                    b = 8
                except KeyError:
                    pass
                except:
                    a = 6
            except:
                pass""",
            """
            try:
                pass
            except:
                try:
                    b = 8
                except KeyError:
                    pass
                except:
                    a = 6""",
        ))
        self._expect_unbound("a", (
            """
            try:
                b = 8
            except:
                b = 6""",
            """
            try:
                try:
                    b = 8
                except:
                    c = d
            except:
                try:
                    b = 6
                except:
                    t = 8
                except:
                    o = y""",
        ))

    def test_try_except_finally(self):
        self._expect_bound("a", (
            """
            try:
                c = 6
            except:
                b = 8
            finally:
                a = 9""",
            """
            try:
                b = 8
            finally:
                a = 6""",
        ))
        self._expect_unbound("a", (
            """
            try:
                b = 8
            finally:
                b = 6""",
            """
            try:
                b = 8
            except:
                b = 9
            finally:
                b = 6""",
        ))

    def test_try_except_finally_nested(self):
        self._expect_bound("a", (
            """
            try:
                c = 6
            except:
                b = 8
            finally:
                try:
                    a = 9
                except:
                    b = 9
                finally:
                    c = 9""",
            """
            try:
                b = 8
            finally:
                try:
                    pass
                finally:
                    a = 6""",
        ))
        self._expect_unbound("a", (
            """
            try:
                b = 8
            finally:
                try:
                    b = 6
                finally:
                    b = 7""",
        ))
551
class Test_touch_import(support.TestCase):
    """Checks for ``fixer_util.touch_import``."""

    def _touch(self, source, package, name):
        """Parse *source*, apply touch_import, and return the tree as text."""
        tree = parse(source)
        fixer_util.touch_import(package, name, tree)
        return str(tree)

    def test_after_docstring(self):
        # The new import goes below a leading docstring.
        rendered = self._touch('"""foo"""\nbar()', None, "foo")
        self.assertEqual(rendered, '"""foo"""\nimport foo\nbar()\n\n')

    def test_after_imports(self):
        # The new import goes below the imports that are already there.
        rendered = self._touch('"""foo"""\nimport bar\nbar()', None, "foo")
        self.assertEqual(rendered,
                         '"""foo"""\nimport bar\nimport foo\nbar()\n\n')

    def test_beginning(self):
        # With neither docstring nor imports, the import comes first.
        rendered = self._touch('bar()', None, "foo")
        self.assertEqual(rendered, 'import foo\nbar()\n\n')

    def test_from_import(self):
        # Giving a package produces a "from ... import ..." statement.
        rendered = self._touch('bar()', "html", "escape")
        self.assertEqual(rendered, 'from html import escape\nbar()\n\n')

    def test_name_import(self):
        # A package of None produces a plain "import ..." statement.
        rendered = self._touch('bar()', None, "cgi")
        self.assertEqual(rendered, 'import cgi\nbar()\n\n')
578
class Test_find_indentation(support.TestCase):
    """Checks for ``fixer_util.find_indentation``."""

    def test_nothing(self):
        # Neither a single top-level call nor empty input has indentation.
        for source in ("node()", ""):
            tree = parse(source)
            self.assertEqual(fixer_util.find_indentation(tree), u"")

    def test_simple(self):
        find_indent = fixer_util.find_indentation

        tree = parse("def f():\n    x()")
        # The root of the tree reports no indentation ...
        self.assertEqual(find_indent(tree), u"")
        # ... while a node below children[0].children[4] reports the four
        # spaces used inside the function.
        inner = tree.children[0].children[4]
        self.assertEqual(find_indent(inner.children[2]), u"    ")

        tree = parse("def f():\n    x()\n    y()")
        inner = tree.children[0].children[4]
        self.assertEqual(find_indent(inner.children[4]), u"    ")
595