import copy

from Cython.Compiler import (ExprNodes, PyrexTypes, MemoryView,
                             ParseTreeTransforms, StringEncoding,
                             Errors)
from Cython.Compiler.ExprNodes import CloneNode, ProxyNode, TupleNode
from Cython.Compiler.Nodes import (FuncDefNode, CFuncDefNode, StatListNode,
                                   DefNode)


class FusedCFuncDefNode(StatListNode):
    """
    This node replaces a function with fused arguments. It deep-copies the
    function for every permutation of fused types, and allocates a new local
    scope for it. It keeps track of the original function in self.node, and
    the entry of the original function in the symbol table is given the
    'fused_cfunction' attribute which points back to us.
    Then when a function lookup occurs (to e.g. call it), the call can be
    dispatched to the right function.

    node                      FuncDefNode    the original function
    nodes                     [FuncDefNode]  list of copies of node with
                                             different specific types
    py_func                   DefNode        the fused python function
                                             subscriptable from Python space
    __signatures__            A DictNode mapping signature specialization
                              strings to PyCFunction nodes
    resulting_fused_function  PyCFunction for the fused DefNode that
                              delegates to specializations
    fused_func_assignment     Assignment of the fused function to the
                              function name
    defaults_tuple            TupleNode of defaults (letting PyCFunctionNode
                              build defaults would result in many different
                              tuples)
    specialized_pycfuncs      List of synthesized pycfunction nodes for the
                              specializations
    code_object               CodeObjectNode shared by all specializations
                              and the fused function

    fused_compound_types      All fused (compound) types (e.g. floating[:])
    """

    __signatures__ = None
    resulting_fused_function = None
    fused_func_assignment = None
    defaults_tuple = None
    decorators = None

    child_attrs = StatListNode.child_attrs + [
        '__signatures__', 'resulting_fused_function', 'fused_func_assignment']

    def __init__(self, node, env):
        """
        Build one specialized copy of `node` per fused-type permutation.

        node  the original (C)FuncDefNode / DefNode with fused arguments
        env   the scope the function is declared in
        """
        super(FusedCFuncDefNode, self).__init__(node.pos)

        self.nodes = []
        self.node = node

        is_def = isinstance(self.node, DefNode)
        if is_def:
            self.copy_def(env)
        else:
            self.copy_cdef(env)

        # Perform some sanity checks. If anything fails, it's a bug
        for n in self.nodes:
            assert not n.entry.type.is_fused
            assert not n.local_scope.return_type.is_fused
            if node.return_type.is_fused:
                assert not n.return_type.is_fused

            if not is_def and n.cfunc_declarator.optional_arg_count:
                assert n.type.op_arg_struct

        node.entry.fused_cfunction = self
        # Copy the nodes as AnalyseDeclarationsTransform will prepend
        # self.py_func to self.stats, as we only want specialized
        # CFuncDefNodes in self.nodes
        self.stats = self.nodes[:]

    def copy_def(self, env):
        """
        Create a copy of the original def or lambda function for specialized
        versions.
        """
        fused_compound_types = PyrexTypes.unique(
            [arg.type for arg in self.node.args if arg.type.is_fused])
        permutations = PyrexTypes.get_all_specialized_permutations(fused_compound_types)

        self.fused_compound_types = fused_compound_types

        if self.node.entry in env.pyfunc_entries:
            env.pyfunc_entries.remove(self.node.entry)

        for cname, fused_to_specific in permutations:
            copied_node = copy.deepcopy(self.node)

            self._specialize_function_args(copied_node.args, fused_to_specific)
            copied_node.return_type = self.node.return_type.specialize(
                fused_to_specific)

            copied_node.analyse_declarations(env)
            self.create_new_local_scope(copied_node, env, fused_to_specific)
            self.specialize_copied_def(copied_node, cname, self.node.entry,
                                       fused_to_specific, fused_compound_types)

            PyrexTypes.specialize_entry(copied_node.entry, cname)
            copied_node.entry.used = True
            env.entries[copied_node.entry.name] = copied_node.entry

            # Stop specializing after the first erroneous copy to avoid
            # reporting the same error once per permutation.
            if not self.replace_fused_typechecks(copied_node):
                break

        self.orig_py_func = self.node
        self.py_func = self.make_fused_cpdef(self.node, env, is_def=True)

    def copy_cdef(self, env):
        """
        Create a copy of the original c(p)def function for all specialized
        versions.
        """
        permutations = self.node.type.get_all_specialized_permutations()

        if self.node.entry in env.cfunc_entries:
            env.cfunc_entries.remove(self.node.entry)

        # Prevent copying of the python function
        self.orig_py_func = orig_py_func = self.node.py_func
        self.node.py_func = None
        if orig_py_func:
            env.pyfunc_entries.remove(orig_py_func.entry)

        fused_types = self.node.type.get_fused_types()
        self.fused_compound_types = fused_types

        for cname, fused_to_specific in permutations:
            copied_node = copy.deepcopy(self.node)

            # Make the types in our CFuncType specific
            spec_type = copied_node.type.specialize(fused_to_specific)
            entry = copied_node.entry

            copied_node.type = spec_type
            entry.type, spec_type.entry = spec_type, entry

            entry.used = (entry.used or
                          self.node.entry.defined_in_pxd or
                          env.is_c_class_scope or
                          entry.is_cmethod)

            if self.node.cfunc_declarator.optional_arg_count:
                self.node.cfunc_declarator.declare_optional_arg_struct(
                    spec_type, env, fused_cname=cname)

            copied_node.return_type = spec_type.return_type
            self.create_new_local_scope(copied_node, env, fused_to_specific)

            # Make the argument types in the CFuncDeclarator specific
            self._specialize_function_args(copied_node.cfunc_declarator.args,
                                           fused_to_specific)

            spec_type.specialize_entry(entry, cname)
            env.cfunc_entries.append(entry)

            # If a cpdef, declare all specialized cpdefs (this
            # also calls analyse_declarations)
            copied_node.declare_cpdef_wrapper(env)
            if copied_node.py_func:
                env.pyfunc_entries.remove(copied_node.py_func.entry)

                self.specialize_copied_def(
                    copied_node.py_func, cname, self.node.entry.as_variable,
                    fused_to_specific, fused_types)

            # Stop specializing after the first erroneous copy to avoid
            # reporting the same error once per permutation.
            if not self.replace_fused_typechecks(copied_node):
                break

        if orig_py_func:
            self.py_func = self.make_fused_cpdef(orig_py_func, env,
                                                 is_def=False)
        else:
            self.py_func = orig_py_func

    def _specialize_function_args(self, args, fused_to_specific):
        """
        Replace each fused argument type in `args` (in place) with its
        specialization from `fused_to_specific`, validating the dtype of
        any resulting memoryview slice type.
        """
        for arg in args:
            if arg.type.is_fused:
                arg.type = arg.type.specialize(fused_to_specific)
                if arg.type.is_memoryviewslice:
                    MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype)

    def create_new_local_scope(self, node, env, f2s):
        """
        Create a new local scope for the copied node and append it to
        self.nodes. A new local scope is needed because the arguments with the
        fused types are already in the local scope, and we need the specialized
        entries created after analyse_declarations on each specialized version
        of the (CFunc)DefNode.
        f2s is a dict mapping each fused type to its specialized version
        """
        node.create_local_scope(env)
        node.local_scope.fused_to_specific = f2s

        # This is copied from the original function, set it to false to
        # stop recursion
        node.has_fused_arguments = False
        self.nodes.append(node)

    def specialize_copied_def(self, node, cname, py_entry, f2s, fused_types):
        """Specialize the copy of a DefNode given the copied node,
        the specialization cname and the original DefNode entry"""
        type_strings = [
            PyrexTypes.specialization_signature_string(fused_type, f2s)
            for fused_type in fused_types
        ]

        # '|'-separated key used in __signatures__; the generated dispatcher
        # splits it on '|' to compare against dest_sig.
        node.specialized_signature_string = '|'.join(type_strings)

        node.entry.pymethdef_cname = PyrexTypes.get_fused_cname(
            cname, node.entry.pymethdef_cname)
        node.entry.doc = py_entry.doc
        node.entry.doc_cname = py_entry.doc_cname

    def replace_fused_typechecks(self, copied_node):
        """
        Branch-prune fused type checks like

            if fused_t is int:
                ...

        Returns False if running the transform issued any new errors (the
        caller should then stop specializing in order to prevent a flood of
        errors), True otherwise.
        """
        num_errors = Errors.num_errors
        transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
            copied_node.local_scope)
        transform(copied_node)

        if Errors.num_errors > num_errors:
            return False

        return True

    def _fused_instance_checks(self, normal_types, pyx_code, env):
        """
        Generate Cython code for instance checks, matching an object to
        specialized types.
        """
        if_ = 'if'
        for specialized_type in normal_types:
            py_type_name = specialized_type.py_type_name()
            specialized_type_name = specialized_type.specialization_string
            # The template below reads if_, py_type_name and
            # specialized_type_name from the context.
            pyx_code.context.update(locals())
            pyx_code.put_chunk(
                u"""
                    {{if_}} isinstance(arg, {{py_type_name}}):
                        dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'
                """)
            if_ = 'elif'

        if not normal_types:
            # we need an 'if' to match the following 'else'
            pyx_code.putln("if 0: pass")

    def _dtype_name(self, dtype):
        """Return an identifier-safe name for `dtype` in generated code."""
        if dtype.is_typedef:
            return '___pyx_%s' % dtype
        return str(dtype).replace(' ', '_')

    def _dtype_type(self, dtype):
        """Return the type spelling to use for `dtype` in generated code."""
        if dtype.is_typedef:
            return self._dtype_name(dtype)
        return str(dtype)

    def _sizeof_dtype(self, dtype):
        """Return a `sizeof(...)` expression for `dtype` in generated code."""
        if dtype.is_pyobject:
            return 'sizeof(void *)'
        else:
            return "sizeof(%s)" % self._dtype_type(dtype)

    def _buffer_check_numpy_dtype_setup_cases(self, pyx_code):
        "Setup some common cases to match dtypes against specializations"
        if pyx_code.indenter("if dtype.kind in ('i', 'u'):"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_int")
            pyx_code.dedent()

        if pyx_code.indenter("elif dtype.kind == 'f':"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_float")
            pyx_code.dedent()

        if pyx_code.indenter("elif dtype.kind == 'c':"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_complex")
            pyx_code.dedent()

        if pyx_code.indenter("elif dtype.kind == 'O':"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_object")
            pyx_code.dedent()

    # Generated-code statements recording whether the current fused
    # argument matched a specialization.
    match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'"
    no_match = "dest_sig[{{dest_sig_idx}}] = None"

    def _buffer_check_numpy_dtype(self, pyx_code, specialized_buffer_types):
        """
        Match a numpy dtype object to the individual specializations.
        """
        self._buffer_check_numpy_dtype_setup_cases(pyx_code)

        for specialized_type in specialized_buffer_types:
            dtype = specialized_type.dtype
            pyx_code.context.update(
                itemsize_match=self._sizeof_dtype(dtype) + " == itemsize",
                signed_match="not (%s_is_signed ^ dtype_signed)" % self._dtype_name(dtype),
                dtype=dtype,
                specialized_type_name=specialized_type.specialization_string)

            # (does this dtype belong to the category, insertion point)
            dtypes = [
                (dtype.is_int, pyx_code.dtype_int),
                (dtype.is_float, pyx_code.dtype_float),
                (dtype.is_complex, pyx_code.dtype_complex)
            ]

            for dtype_category, codewriter in dtypes:
                if dtype_category:
                    cond = '{{itemsize_match}} and arg.ndim == %d' % (
                        specialized_type.ndim,)
                    if dtype.is_int:
                        cond += ' and {{signed_match}}'

                    if codewriter.indenter("if %s:" % cond):
                        codewriter.putln(self.match)
                        codewriter.putln("break")
                        codewriter.dedent()

    def _buffer_parse_format_string_check(self, pyx_code, decl_code,
                                          specialized_type, env):
        """
        For each specialized type, try to coerce the object to a memoryview
        slice of that type. This means obtaining a buffer and parsing the
        format string.
        TODO: separate buffer acquisition from format parsing
        """
        dtype = specialized_type.dtype
        if specialized_type.is_buffer:
            axes = [('direct', 'strided')] * specialized_type.ndim
        else:
            axes = specialized_type.axes

        memslice_type = PyrexTypes.MemoryViewSliceType(dtype, axes)
        memslice_type.create_from_py_utility_code(env)
        pyx_code.context.update(
            coerce_from_py_func=memslice_type.from_py_function,
            dtype=dtype)
        decl_code.putln(
            "{{memviewslice_cname}} {{coerce_from_py_func}}(object)")

        pyx_code.context.update(
            specialized_type_name=specialized_type.specialization_string,
            sizeof_dtype=self._sizeof_dtype(dtype))

        pyx_code.put_chunk(
            u"""
                # try {{dtype}}
                if itemsize == -1 or itemsize == {{sizeof_dtype}}:
                    memslice = {{coerce_from_py_func}}(arg)
                    if memslice.memview:
                        __PYX_XDEC_MEMVIEW(&memslice, 1)
                        # print 'found a match for the buffer through format parsing'
                        %s
                        break
                    else:
                        __pyx_PyErr_Clear()
            """ % self.match)

    def _buffer_checks(self, buffer_types, pyx_code, decl_code, env):
        """
        Generate Cython code to match objects to buffer specializations.
        First try to get a numpy dtype object and match it against the individual
        specializations. If that fails, try naively to coerce the object
        to each specialization, which obtains the buffer each time and tries
        to match the format string.
        """
        if buffer_types:
            if pyx_code.indenter(u"else:"):
                # The first thing to find a match in this loop breaks out of the loop
                if pyx_code.indenter(u"while 1:"):
                    pyx_code.put_chunk(
                        u"""
                            if numpy is not None:
                                if isinstance(arg, numpy.ndarray):
                                    dtype = arg.dtype
                                elif (__pyx_memoryview_check(arg) and
                                      isinstance(arg.base, numpy.ndarray)):
                                    dtype = arg.base.dtype
                                else:
                                    dtype = None

                                itemsize = -1
                                if dtype is not None:
                                    itemsize = dtype.itemsize
                                    kind = ord(dtype.kind)
                                    dtype_signed = kind == ord('i')
                        """)
                    # Step into the `if dtype is not None:` body above.
                    pyx_code.indent(2)
                    pyx_code.named_insertion_point("numpy_dtype_checks")
                    self._buffer_check_numpy_dtype(pyx_code, buffer_types)
                    pyx_code.dedent(2)

                    for specialized_type in buffer_types:
                        self._buffer_parse_format_string_check(
                            pyx_code, decl_code, specialized_type, env)

                    pyx_code.putln(self.no_match)
                    pyx_code.putln("break")
                    pyx_code.dedent()

                pyx_code.dedent()
        else:
            pyx_code.putln("else: %s" % self.no_match)

    def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types):
        """
        If we have any buffer specializations, write out some variable
        declarations and imports.
        """
        decl_code.put_chunk(
            u"""
                ctypedef struct {{memviewslice_cname}}:
                    void *memview

                void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil)
                bint __pyx_memoryview_check(object)
            """)

        pyx_code.local_variable_declarations.put_chunk(
            u"""
                cdef {{memviewslice_cname}} memslice
                cdef Py_ssize_t itemsize
                cdef bint dtype_signed
                cdef char kind

                itemsize = -1
            """)

        pyx_code.imports.put_chunk(
            u"""
                try:
                    import numpy
                except ImportError:
                    numpy = None
            """)

        seen_int_dtypes = set()
        for buffer_type in all_buffer_types:
            dtype = buffer_type.dtype
            if dtype.is_typedef:
                decl_code.putln('ctypedef %s %s "%s"' % (dtype.resolve(),
                                                         self._dtype_name(dtype),
                                                         dtype.declaration_code("")))

            if buffer_type.dtype.is_int:
                # Declare each <name>_is_signed flag only once.
                if str(dtype) not in seen_int_dtypes:
                    seen_int_dtypes.add(str(dtype))
                    pyx_code.context.update(dtype_name=self._dtype_name(dtype),
                                            dtype_type=self._dtype_type(dtype))
                    pyx_code.local_variable_declarations.put_chunk(
                        u"""
                            cdef bint {{dtype_name}}_is_signed
                            {{dtype_name}}_is_signed = <{{dtype_type}}> -1 < 0
                        """)

    def _split_fused_types(self, arg):
        """
        Specialize fused types and split into normal types and buffer types.
        """
        specialized_types = PyrexTypes.get_specialized_types(arg.type)
        seen_py_type_names = set()
        normal_types, buffer_types = [], []
        for specialized_type in specialized_types:
            py_type_name = specialized_type.py_type_name()
            if py_type_name:
                # Several C types can map to one Python type; keep only the
                # first, since an isinstance() check cannot tell them apart.
                if py_type_name in seen_py_type_names:
                    continue
                seen_py_type_names.add(py_type_name)
                normal_types.append(specialized_type)
            elif specialized_type.is_buffer or specialized_type.is_memoryviewslice:
                buffer_types.append(specialized_type)

        return normal_types, buffer_types

    def _unpack_argument(self, pyx_code):
        """
        Emit code fetching the current fused argument from args, kwargs or
        the defaults tuple, using arg_tuple_idx / arg / default_idx from
        the template context.
        """
        # When the argument is neither passed nor defaulted, name it in the
        # error (the former message printed len(args), i.e. the number of
        # arguments *received*, as if it were the expected count).
        pyx_code.put_chunk(
            u"""
                # PROCESSING ARGUMENT {{arg_tuple_idx}}
                if {{arg_tuple_idx}} < len(args):
                    arg = args[{{arg_tuple_idx}}]
                elif '{{arg.name}}' in kwargs:
                    arg = kwargs['{{arg.name}}']
                else:
                {{if arg.default:}}
                    arg = defaults[{{default_idx}}]
                {{else}}
                    raise TypeError("Missing required argument '{{arg.name}}'")
                {{endif}}
            """)

    def make_fused_cpdef(self, orig_py_func, env, is_def):
        """
        This creates the function that is indexable from Python and does
        runtime dispatch based on the argument types. The function gets the
        arg tuple and kwargs dict (or None) and the defaults tuple
        as arguments from the Binding Fused Function's tp_call.
        """
        from Cython.Compiler import TreeFragment, Code, UtilityCode

        # Fused argument types already processed; each distinct fused type
        # gets exactly one dest_sig slot.
        seen_fused_types = set()

        context = {
            'memviewslice_cname': MemoryView.memviewslice_cname,
            'func_args': self.node.args,
            # One dest_sig slot per argument. Only the first len(fused types)
            # slots are ever filled; the rest stay None and are ignored since
            # zip() in the generated matcher truncates.
            'n_fused': len(self.node.args),
            'name': orig_py_func.entry.name,
        }

        pyx_code = Code.PyxCodeWriter(context=context)
        decl_code = Code.PyxCodeWriter(context=context)
        decl_code.put_chunk(
            u"""
                cdef extern from *:
                    void __pyx_PyErr_Clear "PyErr_Clear" ()
            """)
        decl_code.indent()

        pyx_code.put_chunk(
            u"""
                def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
                    dest_sig = [{{for _ in range(n_fused)}}None,{{endfor}}]

                    if kwargs is None:
                        kwargs = {}

                    cdef Py_ssize_t i

                    # instance check body
            """)
        pyx_code.indent()  # indent following code to function body
        pyx_code.named_insertion_point("imports")
        pyx_code.named_insertion_point("local_variable_declarations")

        fused_index = 0
        default_idx = 0
        all_buffer_types = set()
        for i, arg in enumerate(self.node.args):
            if arg.type.is_fused and arg.type not in seen_fused_types:
                seen_fused_types.add(arg.type)

                context.update(
                    arg_tuple_idx=i,
                    arg=arg,
                    dest_sig_idx=fused_index,
                    default_idx=default_idx,
                )

                normal_types, buffer_types = self._split_fused_types(arg)
                self._unpack_argument(pyx_code)
                self._fused_instance_checks(normal_types, pyx_code, env)
                self._buffer_checks(buffer_types, pyx_code, decl_code, env)
                fused_index += 1

                all_buffer_types.update(buffer_types)

            # Every defaulted argument occupies a slot in the defaults
            # tuple, fused or not.
            if arg.default:
                default_idx += 1

        if all_buffer_types:
            self._buffer_declarations(pyx_code, decl_code, all_buffer_types)
            env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))

        pyx_code.put_chunk(
            u"""
                candidates = []
                for sig in signatures:
                    match_found = False
                    for src_type, dst_type in zip(sig.strip('()').split('|'), dest_sig):
                        if dst_type is not None:
                            if src_type == dst_type:
                                match_found = True
                            else:
                                match_found = False
                                break

                    if match_found:
                        candidates.append(sig)

                if not candidates:
                    raise TypeError("No matching signature found")
                elif len(candidates) > 1:
                    raise TypeError("Function call with ambiguous argument types")
                else:
                    return signatures[candidates[0]]
            """)

        fragment_code = pyx_code.getvalue()
        fragment = TreeFragment.TreeFragment(fragment_code, level='module')
        ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
        UtilityCode.declare_declarations_in_scope(decl_code.getvalue(),
                                                  env.global_scope())
        ast.scope = env
        ast.analyse_declarations(env)
        py_func = ast.stats[-1]  # the DefNode
        self.fragment_scope = ast.scope

        if isinstance(self.node, DefNode):
            py_func.specialized_cpdefs = self.nodes[:]
        else:
            py_func.specialized_cpdefs = [n.py_func for n in self.nodes]

        return py_func

    def update_fused_defnode_entry(self, env):
        """
        Make the synthesized dispatcher take over the identity (name, cnames,
        doc, scope) of the original Python function and register it in env.
        """
        copy_attributes = (
            'name', 'pos', 'cname', 'func_cname', 'pyfunc_cname',
            'pymethdef_cname', 'doc', 'doc_cname', 'is_member',
            'scope'
        )

        entry = self.py_func.entry

        for attr in copy_attributes:
            setattr(entry, attr,
                    getattr(self.orig_py_func.entry, attr))

        self.py_func.name = self.orig_py_func.name
        self.py_func.doc = self.orig_py_func.doc

        env.entries.pop('__pyx_fused_cpdef', None)
        if isinstance(self.node, DefNode):
            env.entries[entry.name] = entry
        else:
            env.entries[entry.name].as_variable = entry

        env.pyfunc_entries.append(entry)

        self.py_func.entry.fused_cfunction = self
        for node in self.nodes:
            if isinstance(self.node, DefNode):
                node.fused_py_func = self.py_func
            else:
                node.py_func.fused_py_func = self.py_func
                node.entry.as_variable = entry

        self.synthesize_defnodes()
        self.stats.append(self.__signatures__)

    def analyse_expressions(self, env):
        """
        Analyse the expressions. Take care to only evaluate default arguments
        once and clone the result for all specializations
        """
        for fused_compound_type in self.fused_compound_types:
            for fused_type in fused_compound_type.get_fused_types():
                for specialization_type in fused_type.types:
                    if specialization_type.is_complex:
                        specialization_type.create_declaration_utility_code(env)

        if self.py_func:
            self.__signatures__ = self.__signatures__.analyse_expressions(env)
            self.py_func = self.py_func.analyse_expressions(env)
            self.resulting_fused_function = self.resulting_fused_function.analyse_expressions(env)
            self.fused_func_assignment = self.fused_func_assignment.analyse_expressions(env)

        self.defaults = defaults = []

        for arg in self.node.args:
            if arg.default:
                arg.default = arg.default.analyse_expressions(env)
                defaults.append(ProxyNode(arg.default))
            else:
                defaults.append(None)

        for i, stat in enumerate(self.stats):
            stat = self.stats[i] = stat.analyse_expressions(env)
            if isinstance(stat, FuncDefNode):
                for arg, default in zip(stat.args, defaults):
                    if default is not None:
                        arg.default = CloneNode(default).coerce_to(arg.type, env)

        if self.py_func:
            args = [CloneNode(default) for default in defaults if default]
            self.defaults_tuple = TupleNode(self.pos, args=args)
            self.defaults_tuple = self.defaults_tuple.analyse_types(env, skip_children=True)
            self.defaults_tuple = ProxyNode(self.defaults_tuple)
            self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object)

            fused_func = self.resulting_fused_function.arg
            fused_func.defaults_tuple = CloneNode(self.defaults_tuple)
            fused_func.code_object = CloneNode(self.code_object)

            for i, pycfunc in enumerate(self.specialized_pycfuncs):
                pycfunc.code_object = CloneNode(self.code_object)
                pycfunc = self.specialized_pycfuncs[i] = pycfunc.analyse_types(env)
                pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)
        return self

    def synthesize_defnodes(self):
        """
        Create the __signatures__ dict of PyCFunctionNode specializations.
        """
        if isinstance(self.nodes[0], CFuncDefNode):
            nodes = [node.py_func for node in self.nodes]
        else:
            nodes = self.nodes

        signatures = [
            StringEncoding.EncodedString(node.specialized_signature_string)
            for node in nodes]
        keys = [ExprNodes.StringNode(node.pos, value=sig)
                for node, sig in zip(nodes, signatures)]
        values = [ExprNodes.PyCFunctionNode.from_defnode(node, True)
                  for node in nodes]
        self.__signatures__ = ExprNodes.DictNode.from_pairs(self.pos,
                                                            zip(keys, values))

        self.specialized_pycfuncs = values
        for pycfuncnode in values:
            pycfuncnode.is_specialization = True

    def generate_function_definitions(self, env, code):
        """Emit the dispatcher (if any) and every used specialization."""
        if self.py_func:
            self.py_func.pymethdef_required = True
            self.fused_func_assignment.generate_function_definitions(env, code)

        for stat in self.stats:
            if isinstance(stat, FuncDefNode) and stat.entry.used:
                code.mark_pos(stat.pos)
                stat.generate_function_definitions(env, code)

    def generate_execution_code(self, code):
        """
        Evaluate the shared defaults/code object once, run each stat, attach
        __signatures__ to the fused function object and dispose of temps.
        """
        # Note: all def function specialization are wrapped in PyCFunction
        # nodes in the self.__signatures__ dictnode.
        for default in self.defaults:
            if default is not None:
                default.generate_evaluation_code(code)

        if self.py_func:
            self.defaults_tuple.generate_evaluation_code(code)
            self.code_object.generate_evaluation_code(code)

        for stat in self.stats:
            code.mark_pos(stat.pos)
            if isinstance(stat, ExprNodes.ExprNode):
                stat.generate_evaluation_code(code)
            else:
                stat.generate_execution_code(code)

        if self.__signatures__:
            self.resulting_fused_function.generate_evaluation_code(code)

            code.putln(
                "((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" %
                (self.resulting_fused_function.result(),
                 self.__signatures__.result()))
            # The fused function object now owns the __signatures__ reference.
            code.put_giveref(self.__signatures__.result())

            self.fused_func_assignment.generate_execution_code(code)

            # Dispose of results
            self.resulting_fused_function.generate_disposal_code(code)
            self.defaults_tuple.generate_disposal_code(code)
            self.code_object.generate_disposal_code(code)

        for default in self.defaults:
            if default is not None:
                default.generate_disposal_code(code)

    def annotate(self, code):
        """Forward annotation to every contained stat."""
        for stat in self.stats:
            stat.annotate(code)