from Errors import error, message
import ExprNodes
import Nodes
import Builtin
import PyrexTypes
from Cython import Utils
from PyrexTypes import py_object_type, unspecified_type
from Visitor import CythonTransform, EnvTransform


class TypedExprNode(ExprNodes.ExprNode):
    # Used for declaring assignments of a specified type without a known entry.
    def __init__(self, type):
        self.type = type

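# Shared dummy rhs node used whenever all that is known about an assignment
# is that it assigns some Python object.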
object_expr = TypedExprNode(py_object_type)


class MarkParallelAssignments(EnvTransform):
    # Collects assignments inside parallel blocks (prange, with parallel).
    # Perhaps it's better to move it to ControlFlowAnalysis.

    # tells us whether we're in a normal loop
    in_loop = False

    parallel_errors = False

    def __init__(self, context):
        # Track the parallel block scopes (with parallel, for i in prange())
        self.parallel_block_stack = []
        super(MarkParallelAssignments, self).__init__(context)

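    # Record an assignment to 'lhs'.  Inside a parallel block, the assignment
    # position and any in-place (reduction) operator are stored per entry so
    # that inconsistent reduction operators can be reported.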
    def mark_assignment(self, lhs, rhs, inplace_op=None):
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            if lhs.entry is None:
                # TODO: This shouldn't happen...
                return

            if self.parallel_block_stack:
                parallel_node = self.parallel_block_stack[-1]
                previous_assignment = parallel_node.assignments.get(lhs.entry)

                # If there was a previous assignment to the variable, keep the
                # previous assignment position
                if previous_assignment:
                    pos, previous_inplace_op = previous_assignment

                    if (inplace_op and previous_inplace_op and
                            inplace_op != previous_inplace_op):
                        # x += y; x *= y
                        t = (inplace_op, previous_inplace_op)
                        error(lhs.pos,
                              "Reduction operator '%s' is inconsistent "
                              "with previous reduction operator '%s'" % t)
                else:
                    pos = lhs.pos

                parallel_node.assignments[lhs.entry] = (pos, inplace_op)
                parallel_node.assigned_nodes.append(lhs)

        elif isinstance(lhs, ExprNodes.SequenceNode):
            for arg in lhs.args:
                self.mark_assignment(arg, object_expr)
        else:
            # Could use this info to infer cdef class attributes...
            pass

    def visit_WithTargetAssignmentStatNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.mark_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
        self.visitchildren(node)
        return node

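    # For-in loops: recognise the builtin forms reversed(), enumerate() and
    # (x)range() so that the loop target can be marked as assigned from a C
    # integer expression; any other iterable falls back to an index lookup.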
    def visit_ForInStatNode(self, node):
        # TODO: Remove redundancy with range optimization...
        is_special = False
        sequence = node.iterator.sequence
        target = node.target
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = self.current_env().lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name == 'reversed' and len(sequence.args) == 1:
                        sequence = sequence.args[0]
                    elif function.name == 'enumerate' and len(sequence.args) == 1:
                        if target.is_sequence_constructor and len(target.args) == 2:
                            iterator = sequence.args[0]
                            if iterator.is_name:
                                iterator_type = iterator.infer_type(self.current_env())
                                if iterator_type.is_builtin_type:
                                    # assume that builtin types have a length within Py_ssize_t
                                    self.mark_assignment(
                                        target.args[0],
                                        ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                                          type=PyrexTypes.c_py_ssize_t_type))
                                    target = target.args[1]
                                    sequence = sequence.args[0]
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = self.current_env().lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name in ('range', 'xrange'):
                        is_special = True
                        for arg in sequence.args[:2]:
                            self.mark_assignment(target, arg)
                        if len(sequence.args) > 2:
                            self.mark_assignment(
                                target,
                                ExprNodes.binop_node(node.pos,
                                                     '+',
                                                     sequence.args[0],
                                                     sequence.args[2]))

        if not is_special:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            self.mark_assignment(target, ExprNodes.IndexNode(
                node.pos,
                base=sequence,
                index=ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                        type=PyrexTypes.c_py_ssize_t_type)))

        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        self.mark_assignment(node.target, node.bound1)
        if node.step is not None:
            self.mark_assignment(node.target,
                    ExprNodes.binop_node(node.pos,
                                         '+',
                                         node.bound1,
                                         node.step))
        self.visitchildren(node)
        return node

    def visit_WhileStatNode(self, node):
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        # Can't be assigned to, so there is nothing to mark here.
        return node

    def visit_FromImportStatNode(self, node):
        for name, target in node.items:
            if name != "*":
                self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use fake expressions with the right result type
        if node.star_arg:
            self.mark_assignment(
                node.star_arg, TypedExprNode(Builtin.tuple_type))
        if node.starstar_arg:
            self.mark_assignment(
                node.starstar_arg, TypedExprNode(Builtin.dict_type))
        EnvTransform.visit_FuncDefNode(self, node)
        return node

    def visit_DelStatNode(self, node):
        for arg in node.args:
            self.mark_assignment(arg, arg)
        self.visitchildren(node)
        return node

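    # Track the nesting of parallel blocks and validate it: apart from a
    # prange() inside a parallel section, parallel constructs may not be
    # nested.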
    def visit_ParallelStatNode(self, node):
        if self.parallel_block_stack:
            node.parent = self.parallel_block_stack[-1]
        else:
            node.parent = None

        nested = False
        if node.is_prange:
            if not node.parent:
                node.is_parallel = True
            else:
                node.is_parallel = (node.parent.is_prange or not
                                    node.parent.is_parallel)
                nested = node.parent.is_prange
        else:
            node.is_parallel = True
            # Note: nested 'with parallel()' blocks are handled by
            # ParallelRangeTransform!
            # nested = node.parent
            nested = node.parent and node.parent.is_prange

        self.parallel_block_stack.append(node)

        nested = nested or len(self.parallel_block_stack) > 2
        if not self.parallel_errors and nested and not node.is_prange:
            error(node.pos, "Only prange() may be nested")
            self.parallel_errors = True

        if node.is_prange:
            child_attrs = node.child_attrs
            node.child_attrs = ['body', 'target', 'args']
            self.visitchildren(node)
            node.child_attrs = child_attrs

            self.parallel_block_stack.pop()
            if node.else_clause:
                node.else_clause = self.visit(node.else_clause)
        else:
            self.visitchildren(node)
            self.parallel_block_stack.pop()

        self.parallel_errors = False
        return node

    def visit_YieldExprNode(self, node):
        if self.parallel_block_stack:
            error(node.pos, "Yield not allowed in parallel sections")

        return node

    def visit_ReturnStatNode(self, node):
        node.in_parallel = bool(self.parallel_block_stack)
        return node


class MarkOverflowingArithmetic(CythonTransform):

    # It may be possible to integrate this with the above for
    # performance improvements (though likely not worth it).

    might_overflow = False

    def __call__(self, root):
        self.env_stack = []
        self.env = root.scope
        return super(MarkOverflowingArithmetic, self).__call__(root)

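    # Nodes are classified in three ways while walking the tree:
    #   safe      - resets 'might_overflow' for its children
    #   neutral   - leaves the current flag unchanged
    #   dangerous - sets 'might_overflow' for its children (e.g. arithmetic)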
    def visit_safe_node(self, node):
        self.might_overflow, saved = False, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_neutral_node(self, node):
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
        self.might_overflow, saved = True, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_FuncDefNode(self, node):
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def visit_NameNode(self, node):
        if self.might_overflow:
            entry = node.entry or self.env.lookup(node.name)
            if entry:
                entry.might_overflow = True
        return node

    def visit_BinopNode(self, node):
        if node.operator in '&|^':
            return self.visit_neutral_node(node)
        else:
            return self.visit_dangerous_node(node)

    visit_UnopNode = visit_neutral_node

    visit_UnaryMinusNode = visit_dangerous_node

    visit_InPlaceAssignmentNode = visit_dangerous_node

    visit_Node = visit_safe_node

    def visit_assignment(self, lhs, rhs):
        if (isinstance(rhs, ExprNodes.IntNode)
                and isinstance(lhs, ExprNodes.NameNode)
                and Utils.long_literal(rhs.value)):
            entry = lhs.entry or self.env.lookup(lhs.name)
            if entry:
                entry.might_overflow = True

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.visit_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node


class PyObjectTypeInferer(object):
    """
    If it's not declared, it's a PyObject.
    """
    def infer_types(self, scope):
        """
        Given a dict of entries, map all unspecified types to a specified type.
        """
        for name, entry in scope.entries.items():
            if entry.type is unspecified_type:
                entry.type = py_object_type


class SimpleAssignmentTypeInferer(object):
    """
    Very basic type inference.

    Note: in order to support cross-closure type inference, this must be
    applied to nested scopes in top-down order.
    """
    def set_entry_type(self, entry, entry_type):
        entry.type = entry_type
        for e in entry.all_entries():
            e.type = entry_type

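    # Overall strategy: collect the control-flow assignments of every entry
    # whose type is still unspecified, resolve the assignments' types in
    # dependency order (falling back to partial inference for cycles), and
    # finally give each entry the type spanning all of its assignments.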
    def infer_types(self, scope):
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']

        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None:  # safe mode
            spanning_type = safe_spanning_type
        else:
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    self.set_entry_type(entry, py_object_type)
            return

        # Set of assignments
        assignments = set([])
        assmts_resolved = set([])
        dependencies = {}
        assmt_to_names = {}

        for name, entry in scope.entries.items():
            for assmt in entry.cf_assignments:
                names = assmt.type_dependencies()
                assmt_to_names[assmt] = names
                assmts = set()
                for node in names:
                    assmts.update(node.cf_state)
                dependencies[assmt] = assmts
            if entry.type is unspecified_type:
                assignments.update(entry.cf_assignments)
            else:
                assmts_resolved.update(entry.cf_assignments)

        def infer_name_node_type(node):
            types = [assmt.inferred_type for assmt in node.cf_state]
            if not types:
                node_type = py_object_type
            else:
                entry = node.entry
                node_type = spanning_type(
                    types, entry.might_overflow, entry.pos)
            node.inferred_type = node_type

        def infer_name_node_type_partial(node):
            types = [assmt.inferred_type for assmt in node.cf_state
                     if assmt.inferred_type is not None]
            if not types:
                return
            entry = node.entry
            return spanning_type(types, entry.might_overflow, entry.pos)

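        # An assignment can be resolved once every assignment it depends on
        # has a known type; the driver loop below keeps resolving until no
        # further progress can be made.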
        def resolve_assignments(assignments):
            resolved = set()
            for assmt in assignments:
                deps = dependencies[assmt]
                # If all dependent assignments are resolved, resolve this one.
                if assmts_resolved.issuperset(deps):
                    for node in assmt_to_names[assmt]:
                        infer_name_node_type(node)
                    # Resolve assmt
                    inferred_type = assmt.infer_type()
                    assmts_resolved.add(assmt)
                    resolved.add(assmt)
            assignments.difference_update(resolved)
            return resolved

        def partial_infer(assmt):
            partial_types = []
            for node in assmt_to_names[assmt]:
                partial_type = infer_name_node_type_partial(node)
                if partial_type is None:
                    return False
                partial_types.append((node, partial_type))
            for node, partial_type in partial_types:
                node.inferred_type = partial_type
            assmt.infer_type()
            return True

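        # Cyclic dependencies (e.g. 'x = x + 1' inside a loop) can never be
        # fully resolved above; partial inference breaks the cycle using the
        # types that are already known.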
        partial_assmts = set()
        def resolve_partial(assignments):
            # try to handle circular references
            partials = set()
            for assmt in assignments:
                if assmt in partial_assmts:
                    continue
                if partial_infer(assmt):
                    partials.add(assmt)
                    assmts_resolved.add(assmt)
            partial_assmts.update(partials)
            return partials

        # Infer assignments
        while True:
            if not resolve_assignments(assignments):
                if not resolve_partial(assignments):
                    break
        inferred = set()
        # First pass: every entry whose assignments all resolved gets the
        # spanning type of those assignments; everything else becomes 'object'.
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry_type = py_object_type
            if assmts_resolved.issuperset(entry.cf_assignments):
                types = [assmt.inferred_type for assmt in entry.cf_assignments]
                if types and Utils.all(types):
                    entry_type = spanning_type(
                        types, entry.might_overflow, entry.pos)
                    inferred.add(entry)
            self.set_entry_type(entry, entry_type)

        def reinfer():
            dirty = False
            for entry in inferred:
                types = [assmt.infer_type()
                         for assmt in entry.cf_assignments]
                new_type = spanning_type(types, entry.might_overflow, entry.pos)
                if new_type != entry.type:
                    self.set_entry_type(entry, new_type)
                    dirty = True
            return dirty

        # Propagate the newly inferred types until a fixed point is reached.
        while reinfer():
            pass

        if verbose:
            for entry in inferred:
                message(entry.pos, "inferred '%s' to be of type '%s'" % (
                    entry.name, entry.type))


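# Return a type that both type1 and type2 can be coerced to.  bint is widened
# to a Python object to keep boolean coercions intact, and the Python float
# type is replaced by a plain C double.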
def find_spanning_type(type1, type2):
    if type1 is type2:
        result_type = type1
    elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
        # type inference can break the coercion back to a Python bool
        # if it returns an arbitrary int type here
        return py_object_type
    else:
        result_type = PyrexTypes.spanning_type(type1, type2)
    if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                       Builtin.float_type):
        # Python's float type is just a C double, so it's safe to
        # use the C type instead
        return PyrexTypes.c_double_type
    return result_type

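# aggressive_spanning_type() is used when 'infer_types=True' is in effect,
# safe_spanning_type() in the default safe mode ('infer_types=None'); the
# safe variant falls back to a Python object whenever using the C type might
# change semantics (overflow, lossy coercions).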
def aggressive_spanning_type(types, might_overflow, pos):
    result_type = reduce(find_spanning_type, types)
    if result_type.is_reference:
        result_type = result_type.ref_base_type
    if result_type.is_const:
        result_type = result_type.const_base_type
    if result_type.is_cpp_class:
        result_type.check_nullary_constructor(pos)
    return result_type

def safe_spanning_type(types, might_overflow, pos):
    result_type = reduce(find_spanning_type, types)
    if result_type.is_const:
        result_type = result_type.const_base_type
    if result_type.is_reference:
        result_type = result_type.ref_base_type
    if result_type.is_cpp_class:
        result_type.check_nullary_constructor(pos)
    if result_type.is_pyobject:
        # In theory, any specific Python type is always safe to
        # infer. However, inferring str can cause some existing code
        # to break, since we are also now much more strict about
        # coercion from str to char *. See trac #553.
        if result_type.name == 'str':
            return py_object_type
        else:
            return result_type
    elif result_type is PyrexTypes.c_double_type:
        # Python's float type is just a C double, so it's safe to use
        # the C type instead
        return result_type
    elif result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type
    elif result_type.is_ptr:
        # Any pointer except (signed|unsigned|) char* can't implicitly
        # become a PyObject, and inferring char* is now accepted, too.
        return result_type
    elif result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type
    elif result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type
    # TODO: double complex should be OK as well, but we need
    # to make sure everything is supported.
    elif (result_type.is_int or result_type.is_enum) and not might_overflow:
        return result_type
    return py_object_type


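# Factory for the type inferer used by the compiler; currently always the
# simple assignment-based inferer.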
def get_type_inferer():
    return SimpleAssignmentTypeInferer()