from Errors import error, message
import ExprNodes
import Nodes
import Builtin
import PyrexTypes
from Cython import Utils
from PyrexTypes import py_object_type, unspecified_type
from Visitor import CythonTransform, EnvTransform


class TypedExprNode(ExprNodes.ExprNode):
    # Used for declaring assignments of a specified type without a known entry.
    # Only ``type`` is set here; the base-class initializer is not called.
    def __init__(self, type):
        self.type = type

# Shared fake right-hand side standing for "some Python object".
object_expr = TypedExprNode(py_object_type)


class MarkParallelAssignments(EnvTransform):
    """
    Collects assignments inside parallel blocks (prange, with parallel).

    For each NameNode/PyArgDeclNode target assigned while inside a parallel
    block, the innermost block's ``assignments`` dict maps the target's entry
    to ``(position, inplace_op)`` and the target node is appended to the
    block's ``assigned_nodes`` list.  It also performs some parallel-specific
    checks: nesting of parallel blocks, and yield inside parallel sections.
    """
    # Perhaps it's better to move it to ControlFlowAnalysis.

    # tells us whether we're in a normal loop
    in_loop = False

    # set once a nesting error was reported, to avoid repeating it
    parallel_errors = False

    def __init__(self, context):
        # Track the parallel block scopes (with parallel, for i in prange())
        self.parallel_block_stack = []
        super(MarkParallelAssignments, self).__init__(context)

    def mark_assignment(self, lhs, rhs, inplace_op=None):
        """
        Record an assignment of ``rhs`` to the target ``lhs``.

        Only NameNode/PyArgDeclNode targets are recorded, and only while
        inside a parallel block.  Sequence targets are unpacked recursively,
        each item being treated as an object assignment.  ``rhs`` is not
        inspected by this method.  ``inplace_op`` is the operator of an
        in-place assignment (``node.operator`` of an InPlaceAssignmentNode);
        using two different in-place operators on the same variable inside
        one block is reported as an inconsistent reduction.
        """
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            if lhs.entry is None:
                # TODO: This shouldn't happen...
                return

            if self.parallel_block_stack:
                parallel_node = self.parallel_block_stack[-1]
                previous_assignment = parallel_node.assignments.get(lhs.entry)

                # If there was a previous assignment to the variable, keep the
                # previous assignment position
                if previous_assignment:
                    pos, previous_inplace_op = previous_assignment

                    if (inplace_op and previous_inplace_op and
                            inplace_op != previous_inplace_op):
                        # x += y; x *= y
                        t = (inplace_op, previous_inplace_op)
                        error(lhs.pos,
                              "Reduction operator '%s' is inconsistent "
                              "with previous reduction operator '%s'" % t)
                else:
                    pos = lhs.pos

                parallel_node.assignments[lhs.entry] = (pos, inplace_op)
                parallel_node.assigned_nodes.append(lhs)

        elif isinstance(lhs, ExprNodes.SequenceNode):
            for arg in lhs.args:
                self.mark_assignment(arg, object_expr)
        else:
            # Could use this info to infer cdef class attributes...
            pass

    def visit_WithTargetAssignmentStatNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.mark_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
        self.visitchildren(node)
        return node

    def _builtin_call_name(self, node):
        """
        Return the called function's name if ``node`` is a plain (non-method)
        call to a name that is either undeclared or a builtin in the current
        scope; otherwise return None.
        """
        if not isinstance(node, ExprNodes.SimpleCallNode):
            return None
        function = node.function
        if node.self is not None or not function.is_name:
            return None
        entry = self.current_env().lookup(function.name)
        if entry and not entry.is_builtin:
            # the builtin name is shadowed by a user declaration
            return None
        return function.name

    def visit_ForInStatNode(self, node):
        # TODO: Remove redundancy with range optimization...
        is_special = False
        sequence = node.iterator.sequence
        target = node.target

        # Look through reversed(x) and enumerate(x) to the underlying sequence.
        function_name = self._builtin_call_name(sequence)
        if function_name == 'reversed' and len(sequence.args) == 1:
            sequence = sequence.args[0]
        elif function_name == 'enumerate' and len(sequence.args) == 1:
            if target.is_sequence_constructor and len(target.args) == 2:
                iterator = sequence.args[0]
                if iterator.is_name:
                    iterator_type = iterator.infer_type(self.current_env())
                    if iterator_type.is_builtin_type:
                        # assume that builtin types have a length within Py_ssize_t
                        self.mark_assignment(
                            target.args[0],
                            ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                              type=PyrexTypes.c_py_ssize_t_type))
                        target = target.args[1]
                        sequence = sequence.args[0]

        if self._builtin_call_name(sequence) in ('range', 'xrange'):
            is_special = True
            # The loop variable takes the start/stop values and start + step.
            for arg in sequence.args[:2]:
                self.mark_assignment(target, arg)
            if len(sequence.args) > 2:
                self.mark_assignment(
                    target,
                    ExprNodes.binop_node(node.pos,
                                         '+',
                                         sequence.args[0],
                                         sequence.args[2]))

        if not is_special:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            self.mark_assignment(target, ExprNodes.IndexNode(
                node.pos,
                base=sequence,
                index=ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                        type=PyrexTypes.c_py_ssize_t_type)))

        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        self.mark_assignment(node.target, node.bound1)
        if node.step is not None:
            self.mark_assignment(node.target,
                                 ExprNodes.binop_node(node.pos,
                                                      '+',
                                                      node.bound1,
                                                      node.step))
        self.visitchildren(node)
        return node

    def visit_WhileStatNode(self, node):
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        # Can't be assigned to...  Return the node unchanged: returning None
        # from a transform handler would replace/remove it in the tree.
        return node

    def visit_FromImportStatNode(self, node):
        for name, target in node.items:
            if name != "*":
                self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use fake expressions with the right result type
        if node.star_arg:
            self.mark_assignment(
                node.star_arg, TypedExprNode(Builtin.tuple_type))
        if node.starstar_arg:
            self.mark_assignment(
                node.starstar_arg, TypedExprNode(Builtin.dict_type))
        EnvTransform.visit_FuncDefNode(self, node)
        return node

    def visit_DelStatNode(self, node):
        # Deleting a name is recorded like an assignment to it.
        for arg in node.args:
            self.mark_assignment(arg, arg)
        self.visitchildren(node)
        return node

    def visit_ParallelStatNode(self, node):
        """
        Link ``node`` to its enclosing parallel block, decide whether it
        runs in parallel, report illegal nesting, and visit its children with
        ``node`` pushed on the parallel block stack.
        """
        if self.parallel_block_stack:
            node.parent = self.parallel_block_stack[-1]
        else:
            node.parent = None

        nested = False
        if node.is_prange:
            if not node.parent:
                node.is_parallel = True
            else:
                node.is_parallel = (node.parent.is_prange or not
                                    node.parent.is_parallel)
                nested = node.parent.is_prange
        else:
            node.is_parallel = True
            # Note: nested with parallel() blocks are handled by
            # ParallelRangeTransform!
            # nested = node.parent
            nested = node.parent and node.parent.is_prange

        self.parallel_block_stack.append(node)

        nested = nested or len(self.parallel_block_stack) > 2
        if not self.parallel_errors and nested and not node.is_prange:
            error(node.pos, "Only prange() may be nested")
            self.parallel_errors = True

        if node.is_prange:
            # Visit only body/target/args while this prange is on the stack;
            # the else clause is visited after popping it, so assignments in
            # the else clause are attributed to the enclosing block (if any).
            child_attrs = node.child_attrs
            node.child_attrs = ['body', 'target', 'args']
            self.visitchildren(node)
            node.child_attrs = child_attrs

            self.parallel_block_stack.pop()
            if node.else_clause:
                node.else_clause = self.visit(node.else_clause)
        else:
            self.visitchildren(node)
            self.parallel_block_stack.pop()

        self.parallel_errors = False
        return node

    def visit_YieldExprNode(self, node):
        if self.parallel_block_stack:
            error(node.pos, "Yield not allowed in parallel sections")

        return node

    def visit_ReturnStatNode(self, node):
        node.in_parallel = bool(self.parallel_block_stack)
        return node


class MarkOverflowingArithmetic(CythonTransform):
    """
    Sets ``entry.might_overflow`` on names that occur (directly, or through
    neutral nodes) inside possibly-overflowing arithmetic: binary operators
    other than '&', '|', '^', unary minus, and in-place assignments.  Names
    assigned an integer literal for which ``Utils.long_literal`` is true are
    marked as well.  Any other node type resets the flag for its subtree.
    """

    # It may be possible to integrate this with the above for
    # performance improvements (though likely not worth it).

    might_overflow = False

    def __call__(self, root):
        self.env_stack = []
        self.env = root.scope
        return super(MarkOverflowingArithmetic, self).__call__(root)

    def visit_safe_node(self, node):
        # Children of this node are not considered part of overflowing arithmetic.
        self.might_overflow, saved = False, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_neutral_node(self, node):
        # Keep whatever state the parent established.
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
        # Names in this node's subtree might be involved in an overflow.
        self.might_overflow, saved = True, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_FuncDefNode(self, node):
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def visit_NameNode(self, node):
        if self.might_overflow:
            entry = node.entry or self.env.lookup(node.name)
            if entry:
                entry.might_overflow = True
        return node

    def visit_BinopNode(self, node):
        # Bitwise and/or/xor are treated as neutral; all other binary
        # operators might overflow.  (Explicit membership, not a substring test.)
        if node.operator in ('&', '|', '^'):
            return self.visit_neutral_node(node)
        else:
            return self.visit_dangerous_node(node)

    visit_UnopNode = visit_neutral_node

    visit_UnaryMinusNode = visit_dangerous_node

    visit_InPlaceAssignmentNode = visit_dangerous_node

    visit_Node = visit_safe_node

    def visit_assignment(self, lhs, rhs):
        # Mark names assigned an integer literal that Utils.long_literal
        # flags (presumably one too large for a C long -- confirm).
        if (isinstance(rhs, ExprNodes.IntNode)
                and isinstance(lhs, ExprNodes.NameNode)
                and Utils.long_literal(rhs.value)):
            entry = lhs.entry or self.env.lookup(lhs.name)
            if entry:
                entry.might_overflow = True

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.visit_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node


class PyObjectTypeInferer(object):
    """
    If it's not declared, it's a PyObject.
    """
    def infer_types(self, scope):
        """
        Given a scope, map all entries with an unspecified type to
        py_object_type.
        """
        for entry in scope.entries.values():
            if entry.type is unspecified_type:
                entry.type = py_object_type


class SimpleAssignmentTypeInferer(object):
    """
    Very basic type inference.

    Note: in order to support cross-closure type inference, this must be
    applied to nested scopes in top-down order.
    """
    def set_entry_type(self, entry, entry_type):
        # Propagate the type to every entry returned by entry.all_entries().
        entry.type = entry_type
        for e in entry.all_entries():
            e.type = entry_type

    def infer_types(self, scope):
        """
        Infer a type for every entry of ``scope`` whose type is still
        unspecified, from the entry's control-flow assignments
        (``entry.cf_assignments``).

        The 'infer_types' directive selects the strategy: True uses
        aggressive_spanning_type, None uses safe_spanning_type, and any other
        value makes every unspecified entry a Python object.
        """
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']

        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None: # safe mode
            spanning_type = safe_spanning_type
        else:
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    self.set_entry_type(entry, py_object_type)
            return

        # Set of assignments still waiting for a type
        assignments = set()
        assmts_resolved = set()
        # assignment -> set of assignments whose types it depends on
        dependencies = {}
        # assignment -> name nodes its type depends on
        assmt_to_names = {}

        for entry in scope.entries.values():
            for assmt in entry.cf_assignments:
                names = assmt.type_dependencies()
                assmt_to_names[assmt] = names
                assmts = set()
                for node in names:
                    assmts.update(node.cf_state)
                dependencies[assmt] = assmts
            if entry.type is unspecified_type:
                assignments.update(entry.cf_assignments)
            else:
                assmts_resolved.update(entry.cf_assignments)

        def infer_name_node_type(node):
            # Span the types of all assignments that may reach this name.
            types = [assmt.inferred_type for assmt in node.cf_state]
            if not types:
                node_type = py_object_type
            else:
                entry = node.entry
                node_type = spanning_type(
                    types, entry.might_overflow, entry.pos)
            node.inferred_type = node_type

        def infer_name_node_type_partial(node):
            # Like infer_name_node_type(), but only over the reaching
            # assignments already typed; None if there are none.
            types = [assmt.inferred_type for assmt in node.cf_state
                     if assmt.inferred_type is not None]
            if not types:
                return
            entry = node.entry
            return spanning_type(types, entry.might_overflow, entry.pos)

        def resolve_assignments(assignments):
            # Type every assignment whose dependencies are all resolved.
            resolved = set()
            for assmt in assignments:
                deps = dependencies[assmt]
                # All assignments are resolved
                if assmts_resolved.issuperset(deps):
                    for node in assmt_to_names[assmt]:
                        infer_name_node_type(node)
                    # Resolve assmt
                    assmt.infer_type()
                    assmts_resolved.add(assmt)
                    resolved.add(assmt)
            assignments.difference_update(resolved)
            return resolved

        def partial_infer(assmt):
            # Type an assignment from partially known dependency types;
            # False if some dependency has no known type yet.
            partial_types = []
            for node in assmt_to_names[assmt]:
                partial_type = infer_name_node_type_partial(node)
                if partial_type is None:
                    return False
                partial_types.append((node, partial_type))
            for node, partial_type in partial_types:
                node.inferred_type = partial_type
            assmt.infer_type()
            return True

        partial_assmts = set()
        def resolve_partial(assignments):
            # try to handle circular references
            partials = set()
            for assmt in assignments:
                if assmt in partial_assmts:
                    continue
                if partial_infer(assmt):
                    partials.add(assmt)
                    assmts_resolved.add(assmt)
            partial_assmts.update(partials)
            return partials

        # Infer assignments
        while True:
            if not resolve_assignments(assignments):
                if not resolve_partial(assignments):
                    break
        inferred = set()
        # First pass
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry_type = py_object_type
            if assmts_resolved.issuperset(entry.cf_assignments):
                types = [assmt.inferred_type for assmt in entry.cf_assignments]
                if types and Utils.all(types):
                    entry_type = spanning_type(
                        types, entry.might_overflow, entry.pos)
                    inferred.add(entry)
            self.set_entry_type(entry, entry_type)

        def reinfer():
            # Re-run inference for inferred entries; True if any type changed.
            dirty = False
            for entry in inferred:
                types = [assmt.infer_type()
                         for assmt in entry.cf_assignments]
                new_type = spanning_type(types, entry.might_overflow, entry.pos)
                if new_type != entry.type:
                    self.set_entry_type(entry, new_type)
                    dirty = True
            return dirty

        # types propagation
        while reinfer():
            pass

        if verbose:
            for entry in inferred:
                message(entry.pos, "inferred '%s' to be of type '%s'" % (
                    entry.name, entry.type))


def find_spanning_type(type1, type2):
    """
    Return a type that can hold values of both ``type1`` and ``type2``.

    bint combined with anything else becomes a Python object, and any
    floating point result is normalised to C double.
    """
    if type1 is type2:
        result_type = type1
    elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
        # type inference can break the coercion back to a Python bool
        # if it returns an arbitrary int type here
        return py_object_type
    else:
        result_type = PyrexTypes.spanning_type(type1, type2)
    if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                       Builtin.float_type):
        # Python's float type is just a C double, so it's safe to
        # use the C type instead
        return PyrexTypes.c_double_type
    return result_type


def _unqualified_type(result_type, pos):
    """
    Strip reference and const wrappers from ``result_type`` in whichever
    order they are nested, then call ``check_nullary_constructor(pos)`` if
    the remaining type is a C++ class.  Returns the stripped type.
    """
    while result_type.is_reference or result_type.is_const:
        if result_type.is_reference:
            result_type = result_type.ref_base_type
        else:
            result_type = result_type.const_base_type
    if result_type.is_cpp_class:
        result_type.check_nullary_constructor(pos)
    return result_type


def aggressive_spanning_type(types, might_overflow, pos):
    """
    Span all ``types`` and return the result with reference/const removed.
    ``might_overflow`` is ignored in aggressive mode.
    """
    return _unqualified_type(reduce(find_spanning_type, types), pos)


def safe_spanning_type(types, might_overflow, pos):
    """
    Span all ``types`` like aggressive_spanning_type(), but fall back to a
    Python object unless the result is known to be safe to infer.  Integer
    and enum types are only kept when ``might_overflow`` is false.
    """
    result_type = _unqualified_type(reduce(find_spanning_type, types), pos)
    if result_type.is_pyobject:
        # In theory, any specific Python type is always safe to
        # infer. However, inferring str can cause some existing code
        # to break, since we are also now much more strict about
        # coercion from str to char *. See trac #553.
        if result_type.name == 'str':
            return py_object_type
        else:
            return result_type
    elif result_type is PyrexTypes.c_double_type:
        # Python's float type is just a C double, so it's safe to use
        # the C type instead
        return result_type
    elif result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type
    elif result_type.is_ptr:
        # Any pointer except (signed|unsigned|) char* can't implicitly
        # become a PyObject, and inferring char* is now accepted, too.
        return result_type
    elif result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type
    elif result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type
    # TODO: double complex should be OK as well, but we need
    # to make sure everything is supported.
    elif (result_type.is_int or result_type.is_enum) and not might_overflow:
        return result_type
    return py_object_type


def get_type_inferer():
    """Return the type inferer used by the compiler."""
    return SimpleAssignmentTypeInferer()