1import copy
2
3from Cython.Compiler import (ExprNodes, PyrexTypes, MemoryView,
4                             ParseTreeTransforms, StringEncoding,
5                             Errors)
6from Cython.Compiler.ExprNodes import CloneNode, ProxyNode, TupleNode
7from Cython.Compiler.Nodes import (FuncDefNode, CFuncDefNode, StatListNode,
8                                   DefNode)
9
class FusedCFuncDefNode(StatListNode):
    """
    This node replaces a function with fused arguments. It deep-copies the
    function for every permutation of fused types, and allocates a new local
    scope for it. It keeps track of the original function in self.node, and
    the entry of the original function in the symbol table is given the
    'fused_cfunction' attribute which points back to us.
    Then when a function lookup occurs (to e.g. call it), the call can be
    dispatched to the right function.

    node    FuncDefNode    the original function
    nodes   [FuncDefNode]  list of copies of node with different specific types
    py_func DefNode        the fused python function subscriptable from
                           Python space
    orig_py_func DefNode   the original Python-visible function: the def
                           itself, or the cpdef's Python wrapper (None for a
                           plain cdef)
    __signatures__         A DictNode mapping signature specialization strings
                           to PyCFunction nodes
    resulting_fused_function  PyCFunction for the fused DefNode that delegates
                              to specializations
    fused_func_assignment   Assignment of the fused function to the function name
    defaults_tuple          TupleNode of defaults (letting PyCFunctionNode build
                            defaults would result in many different tuples)
    specialized_pycfuncs    List of synthesized pycfunction nodes for the
                            specializations
    code_object             CodeObjectNode shared by all specializations and the
                            fused function

    fused_compound_types    All fused (compound) types (e.g. floating[:])
    """

    __signatures__ = None
    resulting_fused_function = None
    fused_func_assignment = None
    defaults_tuple = None
    decorators = None

    child_attrs = StatListNode.child_attrs + [
        '__signatures__', 'resulting_fused_function', 'fused_func_assignment']

    def __init__(self, node, env):
        """
        Specialize 'node' (a def or c(p)def with fused argument types) for
        every permutation of its fused types, declaring the copies in 'env'.
        """
        super(FusedCFuncDefNode, self).__init__(node.pos)

        self.nodes = []
        self.node = node

        is_def = isinstance(self.node, DefNode)
        if is_def:
            self.copy_def(env)
        else:
            self.copy_cdef(env)

        # Perform some sanity checks. If anything fails, it's a bug
        for n in self.nodes:
            assert not n.entry.type.is_fused
            assert not n.local_scope.return_type.is_fused
            if node.return_type.is_fused:
                assert not n.return_type.is_fused

            if not is_def and n.cfunc_declarator.optional_arg_count:
                assert n.type.op_arg_struct

        node.entry.fused_cfunction = self
        # Copy the nodes as AnalyseDeclarationsTransform will prepend
        # self.py_func to self.stats, as we only want specialized
        # CFuncDefNodes in self.nodes
        self.stats = self.nodes[:]

    def copy_def(self, env):
        """
        Create a copy of the original def or lambda function for specialized
        versions.
        """
        fused_compound_types = PyrexTypes.unique(
            [arg.type for arg in self.node.args if arg.type.is_fused])
        permutations = PyrexTypes.get_all_specialized_permutations(fused_compound_types)

        self.fused_compound_types = fused_compound_types

        # The unspecialized def must not be emitted as a Python function of
        # its own; only the specializations and the dispatcher are.
        if self.node.entry in env.pyfunc_entries:
            env.pyfunc_entries.remove(self.node.entry)

        for cname, fused_to_specific in permutations:
            copied_node = copy.deepcopy(self.node)

            self._specialize_function_args(copied_node.args, fused_to_specific)
            copied_node.return_type = self.node.return_type.specialize(
                                                    fused_to_specific)

            copied_node.analyse_declarations(env)
            self.create_new_local_scope(copied_node, env, fused_to_specific)
            self.specialize_copied_def(copied_node, cname, self.node.entry,
                                       fused_to_specific, fused_compound_types)

            PyrexTypes.specialize_entry(copied_node.entry, cname)
            copied_node.entry.used = True
            env.entries[copied_node.entry.name] = copied_node.entry

            # Stop specializing on the first error to avoid an error flood.
            if not self.replace_fused_typechecks(copied_node):
                break

        self.orig_py_func = self.node
        self.py_func = self.make_fused_cpdef(self.node, env, is_def=True)

    def copy_cdef(self, env):
        """
        Create a copy of the original c(p)def function for all specialized
        versions.
        """
        permutations = self.node.type.get_all_specialized_permutations()

        if self.node.entry in env.cfunc_entries:
            env.cfunc_entries.remove(self.node.entry)

        # Prevent copying of the python function
        self.orig_py_func = orig_py_func = self.node.py_func
        self.node.py_func = None
        if orig_py_func:
            env.pyfunc_entries.remove(orig_py_func.entry)

        fused_types = self.node.type.get_fused_types()
        self.fused_compound_types = fused_types

        for cname, fused_to_specific in permutations:
            copied_node = copy.deepcopy(self.node)

            # Make the types in our CFuncType specific
            type = copied_node.type.specialize(fused_to_specific)
            entry = copied_node.entry

            copied_node.type = type
            entry.type, type.entry = type, entry

            entry.used = (entry.used or
                          self.node.entry.defined_in_pxd or
                          env.is_c_class_scope or
                          entry.is_cmethod)

            if self.node.cfunc_declarator.optional_arg_count:
                self.node.cfunc_declarator.declare_optional_arg_struct(
                                           type, env, fused_cname=cname)

            copied_node.return_type = type.return_type
            self.create_new_local_scope(copied_node, env, fused_to_specific)

            # Make the argument types in the CFuncDeclarator specific
            self._specialize_function_args(copied_node.cfunc_declarator.args,
                                           fused_to_specific)

            type.specialize_entry(entry, cname)
            env.cfunc_entries.append(entry)

            # If a cpdef, declare all specialized cpdefs (this
            # also calls analyse_declarations)
            copied_node.declare_cpdef_wrapper(env)
            if copied_node.py_func:
                env.pyfunc_entries.remove(copied_node.py_func.entry)

                self.specialize_copied_def(
                        copied_node.py_func, cname, self.node.entry.as_variable,
                        fused_to_specific, fused_types)

            # Stop specializing on the first error to avoid an error flood.
            if not self.replace_fused_typechecks(copied_node):
                break

        if orig_py_func:
            self.py_func = self.make_fused_cpdef(orig_py_func, env,
                                                 is_def=False)
        else:
            self.py_func = orig_py_func

    def _specialize_function_args(self, args, fused_to_specific):
        """
        Replace the fused type of each fused argument in 'args' (in place)
        with its specialization from the 'fused_to_specific' mapping, and
        validate the dtype of specialized memoryview slices.
        """
        for arg in args:
            if arg.type.is_fused:
                arg.type = arg.type.specialize(fused_to_specific)
                if arg.type.is_memoryviewslice:
                    MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype)

    def create_new_local_scope(self, node, env, f2s):
        """
        Create a new local scope for the copied node and append it to
        self.nodes. A new local scope is needed because the arguments with the
        fused types are already in the local scope, and we need the specialized
        entries created after analyse_declarations on each specialized version
        of the (CFunc)DefNode.
        f2s is a dict mapping each fused type to its specialized version
        """
        node.create_local_scope(env)
        node.local_scope.fused_to_specific = f2s

        # This is copied from the original function, set it to false to
        # stop recursion
        node.has_fused_arguments = False
        self.nodes.append(node)

    def specialize_copied_def(self, node, cname, py_entry, f2s, fused_types):
        """
        Specialize the copy of a DefNode given the copied node,
        the specialization cname and the original DefNode entry.

        Sets node.specialized_signature_string to the '|'-joined
        specialization strings of 'fused_types' under the mapping 'f2s',
        makes the PyMethodDef cname unique for this specialization and copies
        the docstring information from 'py_entry'.
        """
        type_strings = [
            PyrexTypes.specialization_signature_string(fused_type, f2s)
                for fused_type in fused_types
        ]

        node.specialized_signature_string = '|'.join(type_strings)

        node.entry.pymethdef_cname = PyrexTypes.get_fused_cname(
                                        cname, node.entry.pymethdef_cname)
        node.entry.doc = py_entry.doc
        node.entry.doc_cname = py_entry.doc_cname

    def replace_fused_typechecks(self, copied_node):
        """
        Branch-prune fused type checks like

            if fused_t is int:
                ...

        Returns True if no new errors were reported while transforming
        'copied_node', and False otherwise, in which case the caller should
        stop specializing in order to prevent a flood of errors.
        """
        num_errors = Errors.num_errors
        transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
                                       copied_node.local_scope)
        transform(copied_node)

        return Errors.num_errors <= num_errors

    def _fused_instance_checks(self, normal_types, pyx_code, env):
        """
        Generate Cython code for instance checks, matching an object to
        specialized types.

        Emits an 'if'/'elif' isinstance() chain recording the matching
        specialization string in dest_sig. 'env' is currently unused.
        """
        if_ = 'if'
        for specialized_type in normal_types:
            py_type_name = specialized_type.py_type_name()
            specialized_type_name = specialized_type.specialization_string
            # Only publish the names the template below uses; dumping
            # locals() here would leak unrelated objects into the shared
            # template context.
            pyx_code.context.update(
                if_=if_,
                py_type_name=py_type_name,
                specialized_type_name=specialized_type_name)
            pyx_code.put_chunk(
                u"""
                    {{if_}} isinstance(arg, {{py_type_name}}):
                        dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'
                """)
            if_ = 'elif'

        if not normal_types:
            # we need an 'if' to match the following 'else'
            pyx_code.putln("if 0: pass")

    def _dtype_name(self, dtype):
        """
        Return an identifier-safe name for 'dtype': typedefs get a
        '___pyx_' prefixed alias, other types have spaces replaced by '_'.
        """
        if dtype.is_typedef:
            return '___pyx_%s' % dtype
        return str(dtype).replace(' ', '_')

    def _dtype_type(self, dtype):
        """
        Return the type name to use for 'dtype' in generated code (the
        ctypedef alias from _dtype_name for typedefs, str(dtype) otherwise).
        """
        if dtype.is_typedef:
            return self._dtype_name(dtype)
        return str(dtype)

    def _sizeof_dtype(self, dtype):
        """
        Return a 'sizeof(...)' expression string for 'dtype'; Python object
        dtypes are measured as pointers.
        """
        if dtype.is_pyobject:
            return 'sizeof(void *)'
        else:
            return "sizeof(%s)" % self._dtype_type(dtype)

    def _buffer_check_numpy_dtype_setup_cases(self, pyx_code):
        """
        Setup some common cases to match dtypes against specializations.

        Emits an if/elif chain on dtype.kind and creates the named insertion
        points dtype_int, dtype_float, dtype_complex and dtype_object into
        which per-specialization checks are later written.
        """
        if pyx_code.indenter("if dtype.kind in ('i', 'u'):"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_int")
            pyx_code.dedent()

        if pyx_code.indenter("elif dtype.kind == 'f':"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_float")
            pyx_code.dedent()

        if pyx_code.indenter("elif dtype.kind == 'c':"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_complex")
            pyx_code.dedent()

        if pyx_code.indenter("elif dtype.kind == 'O':"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_object")
            pyx_code.dedent()

    # Template lines recording a (non-)match for the current fused argument
    # in the generated dispatcher's dest_sig list.
    match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'"
    no_match = "dest_sig[{{dest_sig_idx}}] = None"

    def _buffer_check_numpy_dtype(self, pyx_code, specialized_buffer_types):
        """
        Match a numpy dtype object to the individual specializations.

        For each buffer specialization, writes an itemsize/ndim (and, for
        integers, signedness) check into the insertion point matching its
        dtype category.
        """
        self._buffer_check_numpy_dtype_setup_cases(pyx_code)

        for specialized_type in specialized_buffer_types:
            dtype = specialized_type.dtype
            pyx_code.context.update(
                itemsize_match=self._sizeof_dtype(dtype) + " == itemsize",
                signed_match="not (%s_is_signed ^ dtype_signed)" % self._dtype_name(dtype),
                dtype=dtype,
                specialized_type_name=specialized_type.specialization_string)

            dtypes = [
                (dtype.is_int, pyx_code.dtype_int),
                (dtype.is_float, pyx_code.dtype_float),
                (dtype.is_complex, pyx_code.dtype_complex)
            ]

            for dtype_category, codewriter in dtypes:
                if dtype_category:
                    cond = '{{itemsize_match}} and arg.ndim == %d' % (
                                                    specialized_type.ndim,)
                    if dtype.is_int:
                        cond += ' and {{signed_match}}'

                    if codewriter.indenter("if %s:" % cond):
                        codewriter.putln(self.match)
                        codewriter.putln("break")
                        codewriter.dedent()

    def _buffer_parse_format_string_check(self, pyx_code, decl_code,
                                          specialized_type, env):
        """
        For each specialized type, try to coerce the object to a memoryview
        slice of that type. This means obtaining a buffer and parsing the
        format string.
        TODO: separate buffer acquisition from format parsing
        """
        dtype = specialized_type.dtype
        if specialized_type.is_buffer:
            axes = [('direct', 'strided')] * specialized_type.ndim
        else:
            axes = specialized_type.axes

        memslice_type = PyrexTypes.MemoryViewSliceType(dtype, axes)
        memslice_type.create_from_py_utility_code(env)
        pyx_code.context.update(
            coerce_from_py_func=memslice_type.from_py_function,
            dtype=dtype)
        decl_code.putln(
            "{{memviewslice_cname}} {{coerce_from_py_func}}(object)")

        pyx_code.context.update(
            specialized_type_name=specialized_type.specialization_string,
            sizeof_dtype=self._sizeof_dtype(dtype))

        pyx_code.put_chunk(
            u"""
                # try {{dtype}}
                if itemsize == -1 or itemsize == {{sizeof_dtype}}:
                    memslice = {{coerce_from_py_func}}(arg)
                    if memslice.memview:
                        __PYX_XDEC_MEMVIEW(&memslice, 1)
                        # print 'found a match for the buffer through format parsing'
                        %s
                        break
                    else:
                        __pyx_PyErr_Clear()
            """ % self.match)

    def _buffer_checks(self, buffer_types, pyx_code, decl_code, env):
        """
        Generate Cython code to match objects to buffer specializations.
        First try to get a numpy dtype object and match it against the individual
        specializations. If that fails, try naively to coerce the object
        to each specialization, which obtains the buffer each time and tries
        to match the format string.
        """
        if buffer_types:
            if pyx_code.indenter(u"else:"):
                # The first thing to find a match in this loop breaks out of the loop
                if pyx_code.indenter(u"while 1:"):
                    pyx_code.put_chunk(
                        u"""
                            if numpy is not None:
                                if isinstance(arg, numpy.ndarray):
                                    dtype = arg.dtype
                                elif (__pyx_memoryview_check(arg) and
                                      isinstance(arg.base, numpy.ndarray)):
                                    dtype = arg.base.dtype
                                else:
                                    dtype = None

                                itemsize = -1
                                if dtype is not None:
                                    itemsize = dtype.itemsize
                                    kind = ord(dtype.kind)
                                    dtype_signed = kind == ord('i')
                        """)
                    pyx_code.indent(2)
                    pyx_code.named_insertion_point("numpy_dtype_checks")
                    self._buffer_check_numpy_dtype(pyx_code, buffer_types)
                    pyx_code.dedent(2)

                    for specialized_type in buffer_types:
                        self._buffer_parse_format_string_check(
                                pyx_code, decl_code, specialized_type, env)

                    pyx_code.putln(self.no_match)
                    pyx_code.putln("break")
                    pyx_code.dedent()

                pyx_code.dedent()
        else:
            pyx_code.putln("else: %s" % self.no_match)

    def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types):
        """
        If we have any buffer specializations, write out some variable
        declarations and imports.
        """
        decl_code.put_chunk(
            u"""
                ctypedef struct {{memviewslice_cname}}:
                    void *memview

                void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil)
                bint __pyx_memoryview_check(object)
            """)

        pyx_code.local_variable_declarations.put_chunk(
            u"""
                cdef {{memviewslice_cname}} memslice
                cdef Py_ssize_t itemsize
                cdef bint dtype_signed
                cdef char kind

                itemsize = -1
            """)

        pyx_code.imports.put_chunk(
            u"""
                try:
                    import numpy
                except ImportError:
                    numpy = None
            """)

        # Each distinct integer dtype needs one '<name>_is_signed' flag.
        seen_int_dtypes = set()
        for buffer_type in all_buffer_types:
            dtype = buffer_type.dtype
            if dtype.is_typedef:
                decl_code.putln('ctypedef %s %s "%s"' % (dtype.resolve(),
                                                         self._dtype_name(dtype),
                                                         dtype.declaration_code("")))

            if dtype.is_int:
                if str(dtype) not in seen_int_dtypes:
                    seen_int_dtypes.add(str(dtype))
                    pyx_code.context.update(dtype_name=self._dtype_name(dtype),
                                            dtype_type=self._dtype_type(dtype))
                    pyx_code.local_variable_declarations.put_chunk(
                        u"""
                            cdef bint {{dtype_name}}_is_signed
                            {{dtype_name}}_is_signed = <{{dtype_type}}> -1 < 0
                        """)

    def _split_fused_types(self, arg):
        """
        Specialize fused types and split into normal types and buffer types.

        Normal types are those with a Python type name; only the first
        specialization per Python type name is kept.
        """
        specialized_types = PyrexTypes.get_specialized_types(arg.type)
        seen_py_type_names = set()
        normal_types, buffer_types = [], []
        for specialized_type in specialized_types:
            py_type_name = specialized_type.py_type_name()
            if py_type_name:
                if py_type_name in seen_py_type_names:
                    continue
                seen_py_type_names.add(py_type_name)
                normal_types.append(specialized_type)
            elif specialized_type.is_buffer or specialized_type.is_memoryviewslice:
                buffer_types.append(specialized_type)

        return normal_types, buffer_types

    def _unpack_argument(self, pyx_code):
        """
        Generate Cython code that fetches the current fused argument into
        'arg': positionally from 'args', else by name from 'kwargs', else
        from the 'defaults' tuple. Without a default, the generated code
        raises a TypeError stating how many positional arguments are needed
        to reach this argument and how many were passed.
        """
        pyx_code.put_chunk(
            u"""
                # PROCESSING ARGUMENT {{arg_tuple_idx}}
                if {{arg_tuple_idx}} < len(args):
                    arg = args[{{arg_tuple_idx}}]
                elif '{{arg.name}}' in kwargs:
                    arg = kwargs['{{arg.name}}']
                else:
                {{if arg.default:}}
                    arg = defaults[{{default_idx}}]
                {{else}}
                    raise TypeError("Expected at least %d arguments, got %d" % ({{arg_tuple_idx + 1}}, len(args)))
                {{endif}}
            """)

    def make_fused_cpdef(self, orig_py_func, env, is_def):
        """
        This creates the function that is indexable from Python and does
        runtime dispatch based on the argument types. The function gets the
        arg tuple and kwargs dict (or None) and the defaults tuple
        as arguments from the Binding Fused Function's tp_call.

        'is_def' is currently unused.
        """
        from Cython.Compiler import TreeFragment, Code, UtilityCode

        # Fused argument types already handled; arguments sharing a fused
        # type are dispatched once and share one dest_sig slot.
        seen_fused_types = set()

        # dest_sig holds one slot per distinct fused argument type.
        n_fused = len(set(arg.type for arg in self.node.args
                          if arg.type.is_fused))

        context = {
            'memviewslice_cname': MemoryView.memviewslice_cname,
            'func_args': self.node.args,
            'n_fused': n_fused,
            'name': orig_py_func.entry.name,
        }

        pyx_code = Code.PyxCodeWriter(context=context)
        decl_code = Code.PyxCodeWriter(context=context)
        decl_code.put_chunk(
            u"""
                cdef extern from *:
                    void __pyx_PyErr_Clear "PyErr_Clear" ()
            """)
        decl_code.indent()

        pyx_code.put_chunk(
            u"""
                def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
                    dest_sig = [{{for _ in range(n_fused)}}None,{{endfor}}]

                    if kwargs is None:
                        kwargs = {}

                    cdef Py_ssize_t i

                    # instance check body
            """)
        pyx_code.indent() # indent following code to function body
        pyx_code.named_insertion_point("imports")
        pyx_code.named_insertion_point("local_variable_declarations")

        fused_index = 0
        default_idx = 0
        all_buffer_types = set()
        for i, arg in enumerate(self.node.args):
            if arg.type.is_fused and arg.type not in seen_fused_types:
                seen_fused_types.add(arg.type)

                context.update(
                    arg_tuple_idx=i,
                    arg=arg,
                    dest_sig_idx=fused_index,
                    default_idx=default_idx,
                )

                normal_types, buffer_types = self._split_fused_types(arg)
                self._unpack_argument(pyx_code)
                self._fused_instance_checks(normal_types, pyx_code, env)
                self._buffer_checks(buffer_types, pyx_code, decl_code, env)
                fused_index += 1

                all_buffer_types.update(buffer_types)

            # Index into the defaults tuple counts every defaulted argument,
            # fused or not.
            if arg.default:
                default_idx += 1

        if all_buffer_types:
            self._buffer_declarations(pyx_code, decl_code, all_buffer_types)
            env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))

        pyx_code.put_chunk(
            u"""
                candidates = []
                for sig in signatures:
                    match_found = False
                    for src_type, dst_type in zip(sig.strip('()').split('|'), dest_sig):
                        if dst_type is not None:
                            if src_type == dst_type:
                                match_found = True
                            else:
                                match_found = False
                                break

                    if match_found:
                        candidates.append(sig)

                if not candidates:
                    raise TypeError("No matching signature found")
                elif len(candidates) > 1:
                    raise TypeError("Function call with ambiguous argument types")
                else:
                    return signatures[candidates[0]]
            """)

        fragment_code = pyx_code.getvalue()
        fragment = TreeFragment.TreeFragment(fragment_code, level='module')
        ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
        UtilityCode.declare_declarations_in_scope(decl_code.getvalue(),
                                                  env.global_scope())
        ast.scope = env
        ast.analyse_declarations(env)
        py_func = ast.stats[-1] # the DefNode
        self.fragment_scope = ast.scope

        if isinstance(self.node, DefNode):
            py_func.specialized_cpdefs = self.nodes[:]
        else:
            py_func.specialized_cpdefs = [n.py_func for n in self.nodes]

        return py_func

    def update_fused_defnode_entry(self, env):
        """
        Make the synthesized dispatcher (self.py_func) take over the identity
        of the original Python function: copy its entry attributes, name and
        doc, install the entry in 'env' (directly for defs, as 'as_variable'
        for cpdefs), link every specialization back to the dispatcher and
        build the __signatures__ dict.
        """
        copy_attributes = (
            'name', 'pos', 'cname', 'func_cname', 'pyfunc_cname',
            'pymethdef_cname', 'doc', 'doc_cname', 'is_member',
            'scope'
        )

        entry = self.py_func.entry

        for attr in copy_attributes:
            setattr(entry, attr,
                    getattr(self.orig_py_func.entry, attr))

        self.py_func.name = self.orig_py_func.name
        self.py_func.doc = self.orig_py_func.doc

        env.entries.pop('__pyx_fused_cpdef', None)
        if isinstance(self.node, DefNode):
            env.entries[entry.name] = entry
        else:
            env.entries[entry.name].as_variable = entry

        env.pyfunc_entries.append(entry)

        self.py_func.entry.fused_cfunction = self
        for node in self.nodes:
            if isinstance(self.node, DefNode):
                node.fused_py_func = self.py_func
            else:
                node.py_func.fused_py_func = self.py_func
                node.entry.as_variable = entry

        self.synthesize_defnodes()
        self.stats.append(self.__signatures__)

    def analyse_expressions(self, env):
        """
        Analyse the expressions. Take care to only evaluate default arguments
        once and clone the result for all specializations
        """
        for fused_compound_type in self.fused_compound_types:
            for fused_type in fused_compound_type.get_fused_types():
                for specialization_type in fused_type.types:
                    if specialization_type.is_complex:
                        specialization_type.create_declaration_utility_code(env)

        if self.py_func:
            self.__signatures__ = self.__signatures__.analyse_expressions(env)
            self.py_func = self.py_func.analyse_expressions(env)
            self.resulting_fused_function = self.resulting_fused_function.analyse_expressions(env)
            self.fused_func_assignment = self.fused_func_assignment.analyse_expressions(env)

        self.defaults = defaults = []

        for arg in self.node.args:
            if arg.default:
                arg.default = arg.default.analyse_expressions(env)
                defaults.append(ProxyNode(arg.default))
            else:
                defaults.append(None)

        for i, stat in enumerate(self.stats):
            stat = self.stats[i] = stat.analyse_expressions(env)
            if isinstance(stat, FuncDefNode):
                # Each specialization clones the shared default and coerces
                # it to its own (specialized) argument type.
                for arg, default in zip(stat.args, defaults):
                    if default is not None:
                        arg.default = CloneNode(default).coerce_to(arg.type, env)

        if self.py_func:
            args = [CloneNode(default) for default in defaults if default]
            self.defaults_tuple = TupleNode(self.pos, args=args)
            self.defaults_tuple = self.defaults_tuple.analyse_types(env, skip_children=True)
            self.defaults_tuple = ProxyNode(self.defaults_tuple)
            self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object)

            fused_func = self.resulting_fused_function.arg
            fused_func.defaults_tuple = CloneNode(self.defaults_tuple)
            fused_func.code_object = CloneNode(self.code_object)

            for i, pycfunc in enumerate(self.specialized_pycfuncs):
                pycfunc.code_object = CloneNode(self.code_object)
                pycfunc = self.specialized_pycfuncs[i] = pycfunc.analyse_types(env)
                pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)
        return self

    def synthesize_defnodes(self):
        """
        Create the __signatures__ dict of PyCFunctionNode specializations.
        """
        if isinstance(self.nodes[0], CFuncDefNode):
            nodes = [node.py_func for node in self.nodes]
        else:
            nodes = self.nodes

        signatures = [
            StringEncoding.EncodedString(node.specialized_signature_string)
                for node in nodes]
        keys = [ExprNodes.StringNode(node.pos, value=sig)
                    for node, sig in zip(nodes, signatures)]
        values = [ExprNodes.PyCFunctionNode.from_defnode(node, True)
                              for node in nodes]
        self.__signatures__ = ExprNodes.DictNode.from_pairs(self.pos,
                                                            zip(keys, values))

        self.specialized_pycfuncs = values
        for pycfuncnode in values:
            pycfuncnode.is_specialization = True

    def generate_function_definitions(self, env, code):
        """
        Emit the C function definitions of the dispatcher (if any) and of
        every used specialization.
        """
        if self.py_func:
            self.py_func.pymethdef_required = True
            self.fused_func_assignment.generate_function_definitions(env, code)

        for stat in self.stats:
            if isinstance(stat, FuncDefNode) and stat.entry.used:
                code.mark_pos(stat.pos)
                stat.generate_function_definitions(env, code)

    def generate_execution_code(self, code):
        """
        Evaluate the shared defaults once, run all statements, then create
        the fused function object, attach __signatures__ to it, assign it to
        the function name and dispose of the temporaries.
        """
        # Note: all def function specialization are wrapped in PyCFunction
        # nodes in the self.__signatures__ dictnode.
        for default in self.defaults:
            if default is not None:
                default.generate_evaluation_code(code)

        if self.py_func:
            self.defaults_tuple.generate_evaluation_code(code)
            self.code_object.generate_evaluation_code(code)

        for stat in self.stats:
            code.mark_pos(stat.pos)
            if isinstance(stat, ExprNodes.ExprNode):
                stat.generate_evaluation_code(code)
            else:
                stat.generate_execution_code(code)

        if self.__signatures__:
            self.resulting_fused_function.generate_evaluation_code(code)

            code.putln(
                "((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" %
                                    (self.resulting_fused_function.result(),
                                     self.__signatures__.result()))
            code.put_giveref(self.__signatures__.result())

            self.fused_func_assignment.generate_execution_code(code)

            # Dispose of results
            self.resulting_fused_function.generate_disposal_code(code)
            self.defaults_tuple.generate_disposal_code(code)
            self.code_object.generate_disposal_code(code)

        for default in self.defaults:
            if default is not None:
                default.generate_disposal_code(code)

    def annotate(self, code):
        """Annotate every statement (specializations and helpers)."""
        for stat in self.stats:
            stat.annotate(code)
791