import os
import sys

from Cython.Compiler import CmdLine
from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import *
from Cython.Compiler import Main, Symtab
8
9
class TestNormalizeTree(TransformTest):
    """Tests for the NormalizeTree transform.

    Each test parses a small fragment, optionally runs NormalizeTree over it,
    and compares the resulting node-type listing (``self.treetypes``) with
    the expected one.
    """

    def test_parserbehaviour_is_what_we_coded_for(self):
        # Baseline without any transform: the parser puts a single-statement
        # body directly into the IfClauseNode (no StatListNode wrapper).
        t = self.fragment(u"if x: y").root
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: ExprStatNode
        expr: NameNode
""", self.treetypes(t))

    def test_wrap_singlestat(self):
        # After NormalizeTree the single statement is wrapped in a StatListNode.
        t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_wrap_multistat(self):
        t = self.run_pipeline([NormalizeTree(None)], u"""
            if z:
                x
                y
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
        stats[1]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_statinexpr(self):
        t = self.run_pipeline([NormalizeTree(None)], u"""
            a, b = x, y
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: SingleAssignmentNode
    lhs: TupleNode
      args[0]: NameNode
      args[1]: NameNode
    rhs: TupleNode
      args[0]: NameNode
      args[1]: NameNode
""", self.treetypes(t))

    def test_wrap_offagain(self):
        t = self.run_pipeline([NormalizeTree(None)], u"""
            x
            y
            if z:
                x
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: ExprStatNode
    expr: NameNode
  stats[1]: ExprStatNode
    expr: NameNode
  stats[2]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_pass_eliminated(self):
        t = self.run_pipeline([NormalizeTree(None)], u"pass")
        # assertEqual instead of the deprecated assert_ alias, which was
        # removed in Python 3.12. It also reports the actual length on failure.
        self.assertEqual(0, len(t.stats))
92
93class TestWithTransform(object): # (TransformTest): # Disabled!
94
95    def test_simplified(self):
96        t = self.run_pipeline([WithTransform(None)], u"""
97        with x:
98            y = z ** 3
99        """)
100
101        self.assertCode(u"""
102
103        $0_0 = x
104        $0_2 = $0_0.__exit__
105        $0_0.__enter__()
106        $0_1 = True
107        try:
108            try:
109                $1_0 = None
110                y = z ** 3
111            except:
112                $0_1 = False
113                if (not $0_2($1_0)):
114                    raise
115        finally:
116            if $0_1:
117                $0_2(None, None, None)
118
119        """, t)
120
121    def test_basic(self):
122        t = self.run_pipeline([WithTransform(None)], u"""
123        with x as y:
124            y = z ** 3
125        """)
126        self.assertCode(u"""
127
128        $0_0 = x
129        $0_2 = $0_0.__exit__
130        $0_3 = $0_0.__enter__()
131        $0_1 = True
132        try:
133            try:
134                $1_0 = None
135                y = $0_3
136                y = z ** 3
137            except:
138                $0_1 = False
139                if (not $0_2($1_0)):
140                    raise
141        finally:
142            if $0_1:
143                $0_2(None, None, None)
144
145        """, t)
146
147
class TestInterpretCompilerDirectives(TransformTest):
    """
    This class tests the parallel directives AST-rewriting and importing.
    """

    # Test the parallel directives (c)importing

    import_code = u"""
        cimport cython.parallel
        cimport cython.parallel as par
        from cython cimport parallel as par2
        from cython cimport parallel

        from cython.parallel cimport threadid as tid
        from cython.parallel cimport threadavailable as tavail
        from cython.parallel cimport prange
    """

    # Maps each name bound by import_code to the fully qualified
    # cython.parallel entity it refers to.
    expected_directives_dict = {
        u'cython.parallel': u'cython.parallel',
        u'par': u'cython.parallel',
        u'par2': u'cython.parallel',
        u'parallel': u'cython.parallel',

        u"tid": u"cython.parallel.threadid",
        u"tavail": u"cython.parallel.threadavailable",
        u"prange": u"cython.parallel.prange",
    }

    def setUp(self):
        super(TestInterpretCompilerDirectives, self).setUp()

        compilation_options = Main.CompilationOptions(Main.default_options)
        ctx = compilation_options.create_context()

        transform = InterpretCompilerDirectives(ctx, ctx.compiler_directives)
        transform.module_scope = Symtab.ModuleScope('__main__', None, ctx)
        self.pipeline = [transform]

        # Saved so tearDown can undo any change a test makes to this global.
        self.debug_exception_on_error = DebugFlags.debug_exception_on_error

    def tearDown(self):
        DebugFlags.debug_exception_on_error = self.debug_exception_on_error
        # Mirror setUp, which chains to the base class. Without this call,
        # any cleanup done in TransformTest.tearDown would be skipped.
        super(TestInterpretCompilerDirectives, self).tearDown()

    def _parallel_directives_for(self, code):
        """Run the pipeline over *code* and return the recorded directives."""
        self.run_pipeline(self.pipeline, code)
        return self.pipeline[0].parallel_directives

    def test_parallel_directives_cimports(self):
        self.assertEqual(self._parallel_directives_for(self.import_code),
                         self.expected_directives_dict)

    def test_parallel_directives_imports(self):
        # Same directives, but spelled with plain 'import' instead of 'cimport'.
        code = self.import_code.replace(u'cimport', u'import')
        self.assertEqual(self._parallel_directives_for(code),
                         self.expected_directives_dict)
203
204
# TODO: Re-enable once they're more robust.
# The flag is tested first, so nothing else in the condition is evaluated
# while the debugger tests are disabled.
_ENABLE_DEBUGGER_TESTS = False

if _ENABLE_DEBUGGER_TESTS and sys.version_info[:2] >= (2, 5):
    from Cython.Debugger import DebugWriter
    from Cython.Debugger.Tests.TestLibCython import DebuggerTestCase
else:
    # Skip the tests: with a plain ``object`` base, TestDebugTransform is not
    # a unittest.TestCase, so the unittest loader does not collect it.
    DebuggerTestCase = object
212
class TestDebugTransform(DebuggerTestCase):
    """Validate the XML debug information written for the debugger tests.

    Reads the XML file at ``self.debug_dest``. Presumably DebuggerTestCase
    generates it in its setUp (TODO confirm in Cython.Debugger.Tests).
    """

    def elem_hasattrs(self, elem, attrs):
        """Return True if *elem* has every attribute named in *attrs*."""
        return all(attr in elem.attrib for attr in attrs)

    def test_debug_info(self):
        try:
            # unittest assertions are used instead of ``assert`` statements,
            # which are stripped when Python runs with -O.
            self.assertTrue(os.path.exists(self.debug_dest))

            t = DebugWriter.etree.parse(self.debug_dest)
            # ElementTree's XPath support is limited, so keep paths simple.
            # An explicit './' is how ElementTree already interprets a
            # leading '/', minus the FutureWarning.
            L = list(t.find('./Module/Globals'))
            self.assertTrue(L)
            xml_globals = dict(
                (e.attrib['name'], e.attrib['type']) for e in L)
            # Equal lengths mean no global name appears twice.
            self.assertEqual(len(L), len(xml_globals))

            L = list(t.find('./Module/Functions'))
            self.assertTrue(L)
            xml_funcs = dict((e.attrib['qualified_name'], e) for e in L)
            self.assertEqual(len(L), len(xml_funcs))

            # test globals
            self.assertEqual('CObject', xml_globals.get('c_var'))
            self.assertEqual('PythonObject', xml_globals.get('python_var'))

            # test functions
            funcnames = ('codefile.spam', 'codefile.ham', 'codefile.eggs',
                         'codefile.closure', 'codefile.inner')
            required_xml_attrs = 'name', 'cname', 'qualified_name'
            for funcname in funcnames:
                self.assertTrue(funcname in xml_funcs, funcname)

            # Only 'spam' is inspected in detail below.
            spam = xml_funcs['codefile.spam']

            self.assertEqual(spam.attrib['name'], 'spam')
            self.assertNotEqual('spam', spam.attrib['cname'])
            self.assertTrue(self.elem_hasattrs(spam, required_xml_attrs))

            # test locals of functions
            spam_locals = list(spam.find('Locals'))
            self.assertTrue(spam_locals)
            spam_locals.sort(key=lambda e: e.attrib['name'])
            names = [e.attrib['name'] for e in spam_locals]
            self.assertEqual(list('abcd'), names)
            self.assertTrue(
                self.elem_hasattrs(spam_locals[0], required_xml_attrs))

            # test arguments of functions
            spam_arguments = list(spam.find('Arguments'))
            self.assertTrue(spam_arguments)
            self.assertEqual(1, len(spam_arguments))

            # test step-into functions
            step_into = spam.find('StepIntoFunctions')
            self.assertTrue(step_into is not None)
            spam_stepinto = [x.attrib['name'] for x in step_into]
            self.assertTrue(spam_stepinto)
            self.assertEqual(2, len(spam_stepinto))
            self.assertTrue('puts' in spam_stepinto)
            self.assertTrue('some_c_function' in spam_stepinto)
        except Exception:
            # Dump the generated XML to help diagnose the failure. If the file
            # is missing, skip the dump so an IOError doesn't replace the
            # original error.
            if os.path.exists(self.debug_dest):
                f = open(self.debug_dest)
                try:
                    print(f.read())
                finally:
                    f.close()
            raise
280
281
282
if __name__ == "__main__":
    # Running this file directly runs its tests with the stock unittest runner.
    from unittest import main as run_tests
    run_tests()
286