1from Cython.Compiler.Visitor import CythonTransform
2from Cython.Compiler.ModuleNode import ModuleNode
3from Cython.Compiler.Errors import CompileError
4from Cython.Compiler.UtilityCode import CythonUtilityCode
5from Cython.Compiler.Code import UtilityCode, TempitaUtilityCode
6
7from Cython.Compiler import Options
8from Cython.Compiler import Interpreter
9from Cython.Compiler import PyrexTypes
10from Cython.Compiler import Naming
11from Cython.Compiler import Symtab
12
13
def dedent(text, reindent=0):
    """
    Remove common leading whitespace from ``text`` (via textwrap.dedent).

    If ``reindent`` is positive, every resulting line (including empty
    ones) is then prefixed with that many spaces.
    """
    import textwrap
    stripped = textwrap.dedent(text)
    if reindent <= 0:
        return stripped
    pad = " " * reindent
    return "\n".join([pad + line for line in stripped.split("\n")])
21
class IntroduceBufferAuxiliaryVars(CythonTransform):
    """
    Declare the auxiliary C variables required by buffer-typed variables.

    For every module/function scope, each entry whose type is a buffer gets
    two extra scope variables: one of type PyrexTypes.c_pyx_buffer_nd_type
    (cname built from Naming.pybuffernd_prefix) and one of type
    PyrexTypes.c_pyx_buffer_type (from Naming.pybufferstruct_prefix).  They
    are attached to the entry as ``entry.buffer_aux``.

    The transform also records whether any buffer or memoryview-slice
    variable exists (``buffers_exists``) and whether a ``memoryview`` entry
    defined by Cython utility code is in scope (``using_memoryview``).
    """

    #
    # Entry point
    #

    # Class-level defaults; flipped to True on the instance when detected.
    buffers_exists = False
    using_memoryview = False

    def __call__(self, node):
        """Transform a whole module, then register buffer support code if needed."""
        assert isinstance(node, ModuleNode)
        self.max_ndim = 0
        transformed = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if not self.buffers_exists:
            return transformed
        module_scope = node.scope
        use_bufstruct_declare_code(module_scope)
        use_py2_buffer_functions(module_scope)
        module_scope.use_utility_code(empty_bufstruct_utility)
        return transformed

    #
    # Basic operations for transforms
    #
    def handle_scope(self, node, scope):
        """
        Declare auxiliary variables for the buffer entries of ``scope``.

        Afterwards ``scope.buffer_entries`` holds the buffer entries sorted
        by name, and ``self.scope`` is ``scope``.

        Raises CompileError when buffer variables appear at module scope,
        have a pointer dtype, or exceed Options.buffer_max_dims dimensions.
        """
        all_entries = scope.entries

        # Sorted by name so declaration order does not depend on dict order.
        bufvars = sorted([e for _, e in all_entries.iteritems() if e.type.is_buffer],
                         key=lambda e: e.name)
        if bufvars:
            self.buffers_exists = True
        if any(e.type.is_memoryviewslice for _, e in all_entries.iteritems()):
            self.buffers_exists = True
        if any(entry_name == 'memoryview'
               and isinstance(e.utility_code_definition, CythonUtilityCode)
               for entry_name, e in all_entries.iteritems()):
            self.using_memoryview = True

        if bufvars and isinstance(node, ModuleNode):
            # for now...note that pos is wrong
            raise CompileError(node.pos, "Buffer vars not allowed in module scope")

        for entry in bufvars:
            buftype = entry.type
            if buftype.dtype.is_ptr:
                raise CompileError(node.pos, "Buffers with pointer types not yet supported.")
            if buftype.ndim > Options.buffer_max_dims:
                raise CompileError(node.pos,
                        "Buffer ndims exceeds Options.buffer_max_dims = %d" % Options.buffer_max_dims)
            self.max_ndim = max(self.max_ndim, buftype.ndim)

            # Declare the nd-struct variable first, then the Py_buffer struct.
            aux_vars = []
            for aux_type, prefix in ((PyrexTypes.c_pyx_buffer_nd_type, Naming.pybuffernd_prefix),
                                     (PyrexTypes.c_pyx_buffer_type, Naming.pybufferstruct_prefix)):
                aux_var = scope.declare_var(name=None, cname=scope.mangle(prefix, entry.name),
                                            type=aux_type, pos=node.pos)
                if entry.is_arg:
                    # Arguments are marked used up front; for other variables
                    # NameNode decides whether the aux variable is used.
                    aux_var.used = True
                aux_vars.append(aux_var)
            pybuffernd, rcbuffer = aux_vars
            entry.buffer_aux = Symtab.BufferAux(pybuffernd, rcbuffer)

        scope.buffer_entries = bufvars
        self.scope = scope

    def visit_ModuleNode(self, node):
        return self._handle_and_descend(node, node.scope)

    def visit_FuncDefNode(self, node):
        return self._handle_and_descend(node, node.local_scope)

    def _handle_and_descend(self, node, scope):
        # Process this node's own scope, then recurse into nested nodes.
        self.handle_scope(node, scope)
        self.visitchildren(node)
        return node
113
114#
115# Analysis
116#
# Option names, in the order they may be given positionally (ordered!).
buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast")
# Fallback values for omitted options; "dtype" intentionally has none.
buffer_defaults = dict(ndim=1, mode="full", negative_indices=True, cast=False)
# Only this many leading options may be positional; the rest need keywords.
buffer_positional_options_count = 1

# Error message templates raised by analyse_buffer_options.
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'
129
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
    """
    Validate buffer type options and fill in defaults.

    Must be called during type analysis, as analyse is called
    on the dtype argument.

    posargs and dictargs should consist of a list and a dict
    of tuples (value, pos). Defaults should be a dict of values
    (buffer_defaults when None).

    Returns a dict containing all the options a buffer can have and
    its value (with the positions stripped). When ``need_complete`` is
    false, options that were neither supplied nor defaulted are simply
    left out of the result instead of raising.

    Raises CompileError for too many positional options, unknown or
    duplicated options, missing options (when ``need_complete``), and
    invalid dtype/ndim/mode/boolean values.
    """
    if defaults is None:
        defaults = buffer_defaults

    # Evaluate compile-time option values; positional 0 / "dtype" is a type.
    posargs, dictargs = Interpreter.interpret_compiletime_options(
        posargs, dictargs, type_env=env, type_args=(0, 'dtype'))

    if len(posargs) > buffer_positional_options_count:
        raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)

    options = {}
    for name, (value, pos) in dictargs.iteritems():
        if name not in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
        options[name] = value

    # Positional values are matched to buffer_options in order, so their
    # names are always valid options; only duplicates need checking.
    for name, (value, pos) in zip(buffer_options, posargs):
        if name in options:
            raise CompileError(pos, ERR_BUF_DUP % name)
        options[name] = value

    # Check that they are all there and copy defaults
    for name in buffer_options:
        if name not in options:
            try:
                options[name] = defaults[name]
            except KeyError:
                if need_complete:
                    raise CompileError(globalpos, ERR_BUF_MISSING % name)

    dtype = options.get("dtype")
    if dtype and dtype.is_extension_type:
        raise CompileError(globalpos, ERR_BUF_DTYPE)

    ndim = options.get("ndim")
    if ndim and (not isinstance(ndim, int) or ndim < 0):
        raise CompileError(globalpos, ERR_BUF_NDIM)

    mode = options.get("mode")
    if mode and mode not in ('full', 'strided', 'c', 'fortran'):
        raise CompileError(globalpos, ERR_BUF_MODE)

    def assert_bool(name):
        # An option can only be absent here when need_complete is false,
        # which was accepted above; like dtype/ndim/mode, only validate
        # values that are actually present.
        if name in options and not isinstance(options[name], bool):
            raise CompileError(globalpos, ERR_BUF_BOOL % name)

    assert_bool('negative_indices')
    assert_bool('cast')

    return options
192
193
194#
195# Code generation
196#
197
class BufferEntry(object):
    """
    Code-generation view of one buffer variable.

    Exposes C expressions for the buffer's data pointer and its
    per-dimension stride/shape/suboffset fields (stored in the variable's
    pybuffernd struct), and builds element-lookup expressions.
    """

    def __init__(self, entry):
        self.entry = entry
        self.type = entry.type
        nd_struct = entry.buffer_aux.buflocal_nd_var.cname
        self.cname = nd_struct
        self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % nd_struct
        self.buf_ptr_type = self.entry.type.buffer_ptr_type

    def get_buf_suboffsetvars(self):
        return self._for_all_ndim("%s.diminfo[%d].suboffsets")

    def get_buf_stridevars(self):
        return self._for_all_ndim("%s.diminfo[%d].strides")

    def get_buf_shapevars(self):
        return self._for_all_ndim("%s.diminfo[%d].shape")

    def _for_all_ndim(self, s):
        # Instantiate the "%s ... %d" template once per dimension.
        cname = self.cname
        return [s % (cname, dim) for dim in range(self.type.ndim)]

    def generate_buffer_lookup_code(self, code, index_cnames):
        """
        Return a C expression pointing at the element addressed by
        ``index_cnames``.

        The lookup goes through a per-mode, per-ndim macro/inline function
        whose definition is emitted into the utility sections the first
        time it is needed.
        """
        nd = self.type.ndim
        mode = self.type.mode
        if mode == 'full':
            funcname = "__Pyx_BufPtrFull%dd" % nd
            funcgen = buf_lookup_full_code
            per_dim = zip(index_cnames,
                          self.get_buf_stridevars(),
                          self.get_buf_suboffsetvars())
        else:
            contiguity = {
                'strided': ("__Pyx_BufPtrStrided%dd", buf_lookup_strided_code),
                'c': ("__Pyx_BufPtrCContig%dd", buf_lookup_c_code),
                'fortran': ("__Pyx_BufPtrFortranContig%dd", buf_lookup_fortran_code),
            }
            assert mode in contiguity
            template, funcgen = contiguity[mode]
            funcname = template % nd
            per_dim = zip(index_cnames, self.get_buf_stridevars())
        # Flatten (index, stride[, suboffset]) groups into one argument list.
        params = [arg for group in per_dim for arg in group]

        # Make sure the utility code is available
        globalstate = code.globalstate
        if funcname not in globalstate.utility_codes:
            globalstate.utility_codes.add(funcname)
            funcgen(globalstate['utility_code_proto'],
                    globalstate['utility_code_def'],
                    name=funcname, nd=nd)

        ptr_type_decl = self.buf_ptr_type.declaration_code("")
        return "%s(%s, %s, %s)" % (funcname, ptr_type_decl, self.buf_ptr,
                                   ", ".join(params))
261
262
def get_flags(buffer_aux, buffer_type):
    """
    Return the C expression of PyBUF_* request flags for acquiring a buffer
    of ``buffer_type``: always PyBUF_FORMAT, a mode-specific flag, and
    PyBUF_WRITABLE when ``buffer_aux.writable_needed`` is set.
    """
    mode_flags = {
        'full': 'PyBUF_INDIRECT',
        'strided': 'PyBUF_STRIDES',
        'c': 'PyBUF_C_CONTIGUOUS',
        'fortran': 'PyBUF_F_CONTIGUOUS',
    }
    mode = buffer_type.mode
    assert mode in mode_flags
    parts = ['PyBUF_FORMAT', mode_flags[mode]]
    if buffer_aux.writable_needed:
        parts.append('PyBUF_WRITABLE')
    return '| '.join(parts)
278
def used_buffer_aux_vars(entry):
    """Mark both auxiliary variables of the buffer ``entry`` as used."""
    aux = entry.buffer_aux
    for var in (aux.buflocal_nd_var, aux.rcbuf_var):
        var.used = True
283
def put_unpack_buffer_aux_into_scope(buf_entry, code):
    """
    Emit one line copying the acquired Py_buffer's strides and shape (plus
    suboffsets in 'full' mode) into the per-dimension ``diminfo`` fields of
    the buffer's pybuffernd struct.
    """
    nd_struct = buf_entry.buffer_aux.buflocal_nd_var.cname
    if buf_entry.type.mode == 'full':
        fields = ('strides', 'shape', 'suboffsets')
    else:
        fields = ('strides', 'shape')

    template = "%s.diminfo[%d].%s = %s.rcbuffer->pybuffer.%s[%d];"
    assignments = [template % (nd_struct, dim, field, nd_struct, field, dim)
                   for dim in range(buf_entry.type.ndim)
                   for field in fields]
    code.putln(' '.join(assignments))
301
def put_init_vars(entry, code):
    """
    Emit C statements initialising a buffer variable's auxiliary structs:
    the Py_buffer struct gets a NULL data pointer and zero refcount, and
    the pybuffernd struct gets NULL data and is pointed at that Py_buffer
    struct.
    """
    aux = entry.buffer_aux
    nd_struct = aux.buflocal_nd_var.cname
    buf_struct = aux.rcbuf_var.cname
    # (Initialising the buffer object itself to None is left to the caller.)
    for line in ("%s.pybuffer.buf = NULL;" % buf_struct,
                 "%s.refcount = 0;" % buf_struct,
                 "%s.data = NULL;" % nd_struct,
                 "%s.rcbuffer = &%s;" % (nd_struct, buf_struct)):
        code.putln(line)
314
def put_acquire_arg_buffer(entry, code, pos):
    """
    Emit code acquiring the buffer of the function argument ``entry``.

    The getbuffer call is wrapped in an error_goto_if check for ``pos``;
    afterwards the buffer's per-dimension info is unpacked into the aux
    struct.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    getbuffer_call = get_getbuffer_call(code, entry.cname, entry.buffer_aux, entry.type)

    # Acquire any new buffer; the __pyx_stack array referenced by the
    # getbuffer call is declared in its own C block.
    code.putln("{")
    code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % entry.type.dtype.struct_nesting_depth())
    code.putln(code.error_goto_if("%s == -1" % getbuffer_call, pos))
    code.putln("}")
    # An exception raised in arg parsing cannot be caught, so there is no
    # need to care about the buffer then.
    put_unpack_buffer_aux_into_scope(entry, code)
328
def put_release_buffer_code(code, entry):
    """Emit a __Pyx_SafeReleaseBuffer call for the Py_buffer of ``entry``."""
    code.globalstate.use_utility_code(acquire_utility_code)
    nd_struct = entry.buffer_aux.buflocal_nd_var.cname
    code.putln("__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);" % nd_struct)
332
def get_getbuffer_call(code, obj_cname, buffer_aux, buffer_type):
    """
    Return the C expression calling __Pyx_GetBufferAndValidate for the
    object ``obj_cname``.

    Emits the dtype's __Pyx_TypeInfo (via get_type_information_cname) as a
    side effect. The expression refers to a ``__pyx_stack`` array that the
    caller must have declared in the enclosing C block.
    """
    ndim = buffer_type.ndim
    cast = int(buffer_type.cast)
    flags = get_flags(buffer_aux, buffer_type)
    nd_struct = buffer_aux.buflocal_nd_var.cname
    typeinfo_cname = get_type_information_cname(code, buffer_type.dtype)

    call_template = ("__Pyx_GetBufferAndValidate(&%s.rcbuffer->pybuffer, "
                     "(PyObject*)%s, &%s, %s, %d, "
                     "%d, __pyx_stack)")
    return call_template % (nd_struct, obj_cname, typeinfo_cname, flags, ndim, cast)
344
def put_assign_to_buffer(lhs_cname, rhs_cname, buf_entry,
                         is_initialized, pos, code):
    """
    Generate code for reassigning a buffer variable. This only deals with getting
    the buffer auxiliary structure and variables set up correctly, the assignment
    itself and refcounting is the responsibility of the caller.

    However, the assignment operation may throw an exception so that the reassignment
    never happens.

    Depending on the circumstances there are two possible outcomes:
    - Old buffer released, new acquired, rhs assigned to lhs
    - Old buffer released, new acquired which fails, reacquire old lhs buffer
      (which may or may not succeed).

    lhs_cname, rhs_cname -- C expressions for the target and the new object.
    buf_entry -- entry of the buffer variable (provides buffer_aux and type).
    is_initialized -- whether lhs may already hold an acquired buffer that
                      must be released first.
    pos -- source position used for error reporting.
    code -- the code writer to emit into.
    """

    buffer_aux, buffer_type = buf_entry.buffer_aux, buf_entry.type
    code.globalstate.use_utility_code(acquire_utility_code)
    pybuffernd_struct = buffer_aux.buflocal_nd_var.cname

    code.putln("{")  # Set up necessary stack for getbuffer
    code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % buffer_type.dtype.struct_nesting_depth())

    # "%s" is a placeholder for the object cname, filled in below.
    getbuffer = get_getbuffer_call(code, "%s", buffer_aux, buffer_type)

    if is_initialized:
        # Release any existing buffer
        code.putln('__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);' % pybuffernd_struct)
        # Acquire
        retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
        code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
        code.putln('if (%s) {' % (code.unlikely("%s < 0" % retcode_cname)))
        # If acquisition failed, attempt to reacquire the old buffer
        # before raising the exception. A failure of reacquisition
        # will cause the reacquisition exception to be reported, one
        # can consider working around this later.
        exc_type, exc_value, exc_tb = [
            code.funcstate.allocate_temp(PyrexTypes.py_object_type, manage_ref=False)
            for _ in range(3)]
        code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (exc_type, exc_value, exc_tb))
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
        code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (exc_type, exc_value, exc_tb)) # Do not refnanny these!
        code.globalstate.use_utility_code(raise_buffer_fallback_code)
        code.putln('__Pyx_RaiseBufferFallbackError();')
        code.putln('} else {')
        code.putln('PyErr_Restore(%s, %s, %s);' % (exc_type, exc_value, exc_tb))
        for t in (exc_type, exc_value, exc_tb):
            code.funcstate.release_temp(t)
        code.putln('}')
        code.putln('}')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buf_entry, code)
        code.putln(code.error_goto_if_neg(retcode_cname, pos))
        code.funcstate.release_temp(retcode_cname)
    else:
        # Our entry had no previous value, so set to None when acquisition fails.
        # In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
        # so it suffices to set the buf field to NULL.
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
        code.putln('%s = %s; __Pyx_INCREF(Py_None); %s.rcbuffer->pybuffer.buf = NULL;' %
                   (lhs_cname,
                    PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None"),
                    pybuffernd_struct))
        code.putln(code.error_goto(pos))
        code.put('} else {')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buf_entry, code)
        code.putln('}')

    code.putln("}") # Release stack
415
def put_buffer_lookup_code(entry, index_signeds, index_cnames, directives,
                           pos, code, negative_indices, in_nogil_context):
    """
    Generates code to process indices and calculate an offset into
    a buffer. Returns a C string which gives a pointer which can be
    read from or written to at will (it is an expression so caller should
    store it in a temporary if it is used more than once).

    With the 'boundscheck' directive, inline bounds checks are emitted
    directly into the function body (and negative indices wrapped when
    'wraparound' and ``negative_indices`` allow it); an out-of-bounds
    index raises via __Pyx_RaiseBufferIndexError[Nogil]. Without
    boundscheck, only negative-index wrapping is emitted, if enabled.
    The lookup itself is delegated to a per-ndim helper.

    entry is a BufferEntry
    """
    wrap_negative = directives['wraparound'] and negative_indices

    if directives['boundscheck']:
        # The temp starts at -1 ("all OK"); an out-of-bounds index stores
        # the offending dimension number in it.
        err_dim = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
        code.putln("%s = -1;" % err_dim)
        per_dim = zip(index_signeds, index_cnames, entry.get_buf_shapevars())
        for dim, (signed, cname, shape) in enumerate(per_dim):
            if signed != 0:
                # Signed index: handle the negative case first.
                code.putln("if (%s < 0) {" % cname)
                if wrap_negative:
                    code.putln("%s += %s;" % (cname, shape))
                    code.putln("if (%s) %s = %d;" % (
                        code.unlikely("%s < 0" % cname), err_dim, dim))
                else:
                    code.putln("%s = %d;" % (err_dim, dim))
                code.put("} else ")
                cast = ""
            else:
                # Unsigned index: compare against the shape as size_t.
                cast = "(size_t)"
            # check bounds in positive direction
            upper_check = code.unlikely("%s >= %s%s" % (cname, cast, shape))
            code.putln("if (%s) %s = %d;" % (upper_check, err_dim, dim))

        if in_nogil_context:
            raise_code, func = raise_indexerror_nogil, '__Pyx_RaiseBufferIndexErrorNogil'
        else:
            raise_code, func = raise_indexerror_code, '__Pyx_RaiseBufferIndexError'
        code.globalstate.use_utility_code(raise_code)

        code.putln("if (%s) {" % code.unlikely("%s != -1" % err_dim))
        code.putln('%s(%s);' % (func, err_dim))
        code.putln(code.error_goto(pos))
        code.putln('}')
        code.funcstate.release_temp(err_dim)
    elif wrap_negative:
        # Only fix negative indices.
        per_dim = zip(index_signeds, index_cnames, entry.get_buf_shapevars())
        for cname, shape in [(c, s) for sg, c, s in per_dim if sg != 0]:
            code.putln("if (%s < 0) %s += %s;" % (cname, cname, shape))

    return entry.generate_buffer_lookup_code(code, index_cnames)
481
482
def use_bufstruct_declare_code(env):
    """Register the BufferStructDeclare utility code with ``env``."""
    register = env.use_utility_code
    register(buffer_struct_declare_code)
485
486
def get_empty_bufstruct_code(max_ndim):
    """
    Return a UtilityCode whose prototype declares the static C arrays
    __Pyx_zeros (all 0) and __Pyx_minusones (all -1), each with
    ``max_ndim`` elements.
    """
    template = dedent("""
        static Py_ssize_t __Pyx_zeros[] = {%s};
        static Py_ssize_t __Pyx_minusones[] = {%s};
    """)
    zeros = ", ".join("0" for _ in range(max_ndim))
    minusones = ", ".join("-1" for _ in range(max_ndim))
    return UtilityCode(proto=template % (zeros, minusones))

empty_bufstruct_utility = get_empty_bufstruct_code(Options.buffer_max_dims)
495
def buf_lookup_full_code(proto, defin, name, nd):
    """
    Generates a buffer lookup function for the right number
    of dimensions. The function gives back a void* at the right location.

    A macro ``name`` (emitted to ``proto``) casts the result of the inline
    function ``<name>_imp`` (declared in ``proto``, defined in ``defin``),
    which applies stride and then, when non-negative, the suboffset
    indirection for every dimension.
    """
    dims = range(nd)
    # _i_ndex, _s_tride, sub_o_ffset -- one triple per dimension
    macroargs = ", ".join(["i%d, s%d, o%d" % (d, d, d) for d in dims])
    funcargs = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (d, d, d)
                          for d in dims])

    proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))" % (name, macroargs, name, macroargs))
    proto.putln("static CYTHON_INLINE void* %s_imp(void* buf, %s);" % (name, funcargs))

    header = dedent("""
        static CYTHON_INLINE void* %s_imp(void* buf, %s) {
          char* ptr = (char*)buf;
        """) % (name, funcargs)
    step = dedent("""\
          ptr += s%d * i%d;
          if (o%d >= 0) ptr = *((char**)ptr) + o%d;
        """)
    steps = "".join([step % (d, d, d, d) for d in dims])
    defin.putln(header + steps + "\nreturn ptr;\n}")
515
def buf_lookup_strided_code(proto, defin, name, nd):
    """
    Generates a buffer lookup function for the right number
    of dimensions. The function gives back a void* at the right location.

    Emitted as a single macro into ``proto``: the byte offset is the sum of
    index * stride over all dimensions. ``defin`` is unused.
    """
    dims = range(nd)
    # _i_ndex, _s_tride
    macro_params = ", ".join(["i%d, s%d" % (d, d) for d in dims])
    byte_offset = " + ".join(["i%d * s%d" % (d, d) for d in dims])
    proto.putln("#define %s(type, buf, %s) (type)((char*)buf + %s)" % (name, macro_params, byte_offset))
525
def buf_lookup_c_code(proto, defin, name, nd):
    """
    Similar to strided lookup, but assumes the last dimension is contiguous,
    so its index is added to the typed pointer without a stride
    multiplication. The (index, stride) macro signature is kept the same as
    the strided lookup for now. ``defin`` is unused.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    last = nd - 1
    macro_params = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
    outer_offset = " + ".join(["i%d * s%d" % (d, d) for d in range(last)])
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)"
                % (name, macro_params, outer_offset, last))
538
def buf_lookup_fortran_code(proto, defin, name, nd):
    """
    Like C lookup, but the first index is optimized instead: dimension 0 is
    assumed contiguous and its index is added to the typed pointer without
    a stride multiplication. ``defin`` is unused.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    macro_params = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
    inner_offset = " + ".join(["i%d * s%d" % (d, d) for d in range(1, nd)])
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i0)"
                % (name, macro_params, inner_offset))
549
550
def use_py2_buffer_functions(env):
    """Register the GetAndReleaseBuffer utility code with ``env``."""
    utility = GetAndReleaseBufferUtilityCode()
    env.use_utility_code(utility)
553
class GetAndReleaseBufferUtilityCode(object):
    """
    Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.

    For >= 2.6 we do double mode -- use the new buffer interface on objects
    which have the right tp_flags set, but emulation otherwise.

    The generated code depends on which extension types of the module (and
    of the modules it cimports) define ``__getbuffer__`` /
    ``__releasebuffer__``, so it is rendered in ``put_code`` from the
    "GetAndReleaseBuffer" template in Buffer.c. All instances compare and
    hash equal, presumably so repeated registrations are deduplicated.
    """

    requires = None
    is_cython_utility = False

    def __init__(self):
        pass

    def __eq__(self, other):
        return isinstance(other, GetAndReleaseBufferUtilityCode)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two distinct
        # instances would compare unequal through identity comparison.
        return not self.__eq__(other)

    def __hash__(self):
        # Constant hash, consistent with every instance being equal.
        return 24342342

    def get_tree(self):
        # No Cython-level tree for this utility code.
        return None

    def put_code(self, output):
        """Render the proto and implementation into ``output``'s utility sections."""
        code = output['utility_code_def']
        proto_code = output['utility_code_proto']
        env = output.module_node.scope
        cython_scope = env.context.cython_scope

        # Search all types for __getbuffer__ overloads; collects
        # (typeptr_cname, getbuffer_cname, releasebuffer_cname_or_None).
        types = []
        visited_scopes = set()

        def find_buffer_types(scope):
            if scope in visited_scopes:
                return
            visited_scopes.add(scope)
            for m in scope.cimported_modules:
                find_buffer_types(m)
            for e in scope.type_entries:
                # Types defined by Cython utility code are skipped.
                if isinstance(e.utility_code_definition, CythonUtilityCode):
                    continue
                t = e.type
                if t.is_extension_type:
                    # From the builtin cython scope only consider used types.
                    if scope is cython_scope and not e.used:
                        continue
                    release = get = None
                    for x in t.scope.pyfunc_entries:
                        if x.name == u"__getbuffer__":
                            get = x.func_cname
                        elif x.name == u"__releasebuffer__":
                            release = x.func_cname
                    if get:
                        types.append((t.typeptr_cname, get, release))

        find_buffer_types(env)

        util_code = TempitaUtilityCode.load(
            "GetAndReleaseBuffer", from_file="Buffer.c",
            context=dict(types=types))

        proto = util_code.format_code(util_code.proto)
        impl = util_code.format_code(
            util_code.inject_string_constants(util_code.impl, output)[1])

        proto_code.putln(proto)
        code.putln(impl)
614
615
def mangle_dtype_name(dtype):
    """
    Return an identifier-safe name for ``dtype``, used to build the names
    of its type-info structs.

    Python objects map to "object" and pointers to "ptr". Otherwise the C
    declaration is used with spaces and brackets turned into underscores;
    typedefs and structs/unions get an "nn_" prefix so user-defined names
    stay separate from builtin ones (consider "typedef float unsigned_int").
    """
    if dtype.is_pyobject:
        return "object"
    if dtype.is_ptr:
        return "ptr"
    prefix = "nn_" if (dtype.is_typedef or dtype.is_struct_or_union) else ""
    mangled = dtype.declaration_code("")
    for ch in " []":
        mangled = mangled.replace(ch, "_")
    return prefix + mangled
631
def get_type_information_cname(code, dtype, maxdepth=None):
    """
    Output the run-time type information (__Pyx_TypeInfo) for given dtype,
    and return the name of the type info struct.

    Structs with two floats of the same size are encoded as complex numbers.
    One can distinguish between complex numbers declared as struct or with
    native encoding by inspecting to see if the fields field of the type is
    filled in.

    Parameters:
      code     -- code writer; definitions are written to its
                  globalstate['typeinfo'] section, and
                  globalstate.utility_codes records names already emitted.
      dtype    -- the element type to describe (array types are described
                  by their innermost base type plus the array sizes).
      maxdepth -- remaining struct nesting budget; defaults to
                  dtype.struct_nesting_depth() and is decremented on each
                  recursion into struct fields.

    Returns "__Pyx_TypeInfo_<mangled dtype name>", or "<error>" when dtype
    is an error type.
    """
    namesuffix = mangle_dtype_name(dtype)
    name = "__Pyx_TypeInfo_%s" % namesuffix
    structinfo_name = "__Pyx_StructFields_%s" % namesuffix

    if dtype.is_error: return "<error>"

    # It's critical that walking the type info doesn't use more stack
    # depth than dtype.struct_nesting_depth() returns, so use an assertion for this
    if maxdepth is None: maxdepth = dtype.struct_nesting_depth()
    if maxdepth <= 0:
        assert False

    # Emit each type info struct only once; utility_codes doubles as the
    # registry of names already written.
    if name not in code.globalstate.utility_codes:
        code.globalstate.utility_codes.add(name)
        typecode = code.globalstate['typeinfo']

        # Peel off array dimensions, recording their sizes; everything
        # below describes the innermost element type.
        arraysizes = []
        if dtype.is_array:
            while dtype.is_array:
                arraysizes.append(dtype.size)
                dtype = dtype.base_type

        complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()

        declcode = dtype.declaration_code("")
        if dtype.is_simple_buffer_dtype():
            # No field table for simple dtypes.
            structinfo_name = "NULL"
        elif dtype.is_struct:
            fields = dtype.scope.var_entries
            # Must pre-call all used types in order not to recurse utility code
            # writing.
            assert len(fields) > 0
            types = [get_type_information_cname(code, f.type, maxdepth - 1)
                     for f in fields]
            # Field table: one {&typeinfo, "name", offset} entry per field,
            # terminated by a {NULL, NULL, 0} sentinel.
            typecode.putln("static __Pyx_StructField %s[] = {" % structinfo_name, safe=True)
            for f, typeinfo in zip(fields, types):
                typecode.putln('  {&%s, "%s", offsetof(%s, %s)},' %
                           (typeinfo, f.name, dtype.declaration_code(""), f.cname), safe=True)
            typecode.putln('  {NULL, NULL, 0}', safe=True)
            typecode.putln("};", safe=True)
        else:
            assert False

        rep = str(dtype)

        # typegroup is a C char expression classifying the element type:
        # 'H' for char, 'I'/'U' for signed/unsigned ints, 'C' complex,
        # 'R' float, 'S' struct, 'O' Python object.
        flags = "0"
        is_unsigned = "0"
        if dtype is PyrexTypes.c_char_type:
            is_unsigned = "IS_UNSIGNED(%s)" % declcode
            typegroup = "'H'"
        elif dtype.is_int:
            is_unsigned = "IS_UNSIGNED(%s)" % declcode
            typegroup = "%s ? 'U' : 'I'" % is_unsigned
        elif complex_possible or dtype.is_complex:
            typegroup = "'C'"
        elif dtype.is_float:
            typegroup = "'R'"
        elif dtype.is_struct:
            typegroup = "'S'"
            if dtype.packed:
                flags = "__PYX_BUF_FLAGS_PACKED_STRUCT"
        elif dtype.is_pyobject:
            typegroup = "'O'"
        else:
            assert False, dtype

        # Initializer values, in order: name, dtype string, field table,
        # sizeof, array sizes (or 0), number of array dimensions,
        # typegroup, is_unsigned, flags.
        typeinfo = ('static __Pyx_TypeInfo %s = '
                        '{ "%s", %s, sizeof(%s), { %s }, %s, %s, %s, %s };')
        tup = (name, rep, structinfo_name, declcode,
               ', '.join([str(x) for x in arraysizes]) or '0', len(arraysizes),
               typegroup, is_unsigned, flags)
        typecode.putln(typeinfo % tup, safe=True)

    return name
716
def load_buffer_utility(util_code_name, context=None, **kwargs):
    """
    Load the utility code section ``util_code_name`` from Buffer.c.

    With a ``context`` it is loaded through TempitaUtilityCode (receiving
    that context); otherwise through plain UtilityCode. Extra keyword
    arguments are forwarded unchanged.
    """
    if context is not None:
        return TempitaUtilityCode.load(util_code_name, "Buffer.c", context=context, **kwargs)
    return UtilityCode.load(util_code_name, "Buffer.c", **kwargs)
722
# Tempita context shared by the templated buffer utility codes below.
context = {'max_dims': str(Options.buffer_max_dims)}
buffer_struct_declare_code = load_buffer_utility(
    "BufferStructDeclare", context=context)


# Utility function to set the right exception.
# The caller should immediately goto_error.
raise_indexerror_code = load_buffer_utility("BufferIndexError")
raise_indexerror_nogil = load_buffer_utility("BufferIndexErrorNogil")

raise_buffer_fallback_code = load_buffer_utility("BufferFallbackError")
buffer_structs_code = load_buffer_utility(
    "BufferFormatStructs",
    proto_block='utility_code_proto_before_types')
acquire_utility_code = load_buffer_utility(
    "BufferFormatCheck",
    context=context,
    requires=[buffer_structs_code])

# See utility code BufferFormatFromTypeInfo
_typeinfo_to_format_code = load_buffer_utility(
    "TypeInfoToFormat",
    context={},
    requires=[buffer_structs_code])
typeinfo_compare_code = load_buffer_utility(
    "TypeInfoCompare",
    context={},
    requires=[buffer_structs_code])
745