1"""Flexible enumeration of C types."""
2
3from Enumeration import *
4
5# TODO:
6
7#  - struct improvements (flexible arrays, packed &
8#    unpacked, alignment)
9#  - objective-c qualified id
10#  - anonymous / transparent unions
11#  - VLAs
12#  - block types
13#  - K&R functions
14#  - pass arguments of different types (test extension, transparent union)
15#  - varargs
16
17###
18# Actual type types
19
class Type:
    """Common base for every generated C type description.

    Types that have no direct spelling are referred to through a typedef:
    ``getTypeName`` asks the subclass for a ``getTypedefDef(name, printer)``
    declaration, hands it to the printer, and returns the typedef name.
    """

    def isBitField(self):
        """Return False: ordinary types are not bit-fields."""
        return False

    def isPaddingBitField(self):
        """Return False: ordinary types are not zero-width padding bit-fields."""
        return False

    def getTypeName(self, printer):
        """Emit a typedef for this type through *printer* and return its name.

        The typedef is named ``T<k>`` where ``k`` is ``len(printer.types)``
        at the moment of the call (sampled before the definition is built).
        """
        typedefName = 'T' + str(len(printer.types))
        declaration = self.getTypedefDef(typedefName, printer)
        printer.addDeclaration(declaration)
        return typedefName
32
class BuiltinType(Type):
    """A named scalar C type such as ``char`` or ``int``.

    ``size`` is what ``sizeof`` reports (ArrayType divides a vector's byte
    size by it, so it is presumably a byte count). ``bitFieldSize`` is the
    bit-field width when this type is used as a bit-field member, or None
    for an ordinary member; a width of 0 marks an unnamed zero-width
    padding bit-field.
    """

    def __init__(self, name, size, bitFieldSize=None):
        self.name = name
        self.size = size
        self.bitFieldSize = bitFieldSize

    def isBitField(self):
        """Return True if this describes a bit-field member."""
        return self.bitFieldSize is not None

    def isPaddingBitField(self):
        """Return True for a zero-width (padding) bit-field."""
        # Compare by value: 'is 0' depends on CPython's small-int cache and
        # is a SyntaxWarning on Python 3.8+. None == 0 is False, as required.
        return self.bitFieldSize == 0

    def getBitFieldSize(self):
        """Return the bit-field width; only valid when isBitField() is True."""
        assert self.isBitField()
        return self.bitFieldSize

    def getTypeName(self, printer):
        """Builtins are spelled directly; no typedef is emitted."""
        return self.name

    def sizeof(self):
        return self.size

    def __str__(self):
        return self.name
57
class EnumType(Type):
    """An enumeration whose enumerators get generated names.

    ``enumerators`` supplies one optional initializer per enumerator; a
    falsy entry (e.g. None) leaves that enumerator without ``= value``.
    Each instance takes the next value of a class-wide counter as its
    ``unique_id``, which is embedded in every enumerator name so that names
    from different instances are distinct.
    """

    unique_id = 0

    def __init__(self, index, enumerators):
        self.index = index
        self.enumerators = enumerators
        # self.__class__ (not type(self)): Type is an old-style class on Py2.
        cls = self.__class__
        self.unique_id = cls.unique_id
        cls.unique_id += 1

    def getEnumerators(self):
        """Return the comma-separated enumerator list for the enum body."""
        pieces = []
        for position, init in enumerate(self.enumerators):
            piece = 'enum%dval%d_%d' % (self.index, position, self.unique_id)
            if init:
                piece += ' = %s' % init
            pieces.append(piece)
        return ', '.join(pieces)

    def __str__(self):
        return 'enum { %s }' % (self.getEnumerators(),)

    def getTypedefDef(self, name, printer):
        """Return a typedef that both tags and names the enum as *name*."""
        body = self.getEnumerators()
        return 'typedef enum %s { %s } %s;' % (name, body, name)
83
class RecordType(Type):
    """A struct (``isUnion`` false) or union whose members are Type objects."""

    def __init__(self, index, isUnion, fields):
        self.index = index
        self.isUnion = isUnion
        # Materialize: both __str__ and getTypedefDef iterate the fields, so a
        # one-shot iterator (e.g. a Python 3 map object) must not be stored.
        self.fields = list(fields)
        self.name = None

    def __str__(self):
        def describe(t):
            if t.isBitField():
                return "%s : %d;" % (t, t.getBitFieldSize())
            return "%s;" % t

        keyword = ('struct', 'union')[self.isUnion]
        return '%s { %s }' % (keyword,
                              ' '.join([describe(t) for t in self.fields]))

    def getTypedefDef(self, name, printer):
        """Return a typedef declaring the record, tagged and named *name*.

        Member ``i`` is named ``field<i>``; member types are spelled through
        ``printer.getTypeName``.
        """
        members = []
        # Unpack (i, t) in the loop header: the old 'def getField((i, t))'
        # tuple-parameter form is a SyntaxError on Python 3 (PEP 3113).
        for i, t in enumerate(self.fields):
            if t.isBitField():
                if t.isPaddingBitField():
                    # A zero-width bit-field must be unnamed in C.
                    members.append('%s : 0;' % (printer.getTypeName(t),))
                else:
                    members.append('%s field%d : %d;' % (printer.getTypeName(t), i,
                                                         t.getBitFieldSize()))
            else:
                members.append('%s field%d;' % (printer.getTypeName(t), i))
        # Name the struct for more readable LLVM IR.
        return 'typedef %s %s { %s } %s;' % (('struct', 'union')[self.isUnion],
                                            name, ' '.join(members), name)
115
class ArrayType(Type):
    """A C array, or a GCC vector type when ``isVector`` is true.

    For a vector, ``size`` is the total size in bytes; it must be positive
    and a multiple of ``elementType.sizeof()``. For an array, ``size`` is
    the element count (0 allowed), or None for an incomplete ``T[]``.
    ``numElements`` holds the resulting element count.
    """

    def __init__(self, index, isVector, elementType, size):
        if not isVector:
            assert size is None or size >= 0
        else:
            # Note that for vectors, this is the size in bytes.
            assert size > 0
        self.index = index
        self.isVector = isVector
        self.elementType = elementType
        self.size = size
        if not isVector:
            self.numElements = self.size
        else:
            bytesPerElement = self.elementType.sizeof()
            assert self.size % bytesPerElement == 0
            self.numElements = self.size // bytesPerElement

    def __str__(self):
        if self.isVector:
            return 'vector (%s)[%d]' % (self.elementType, self.size)
        if self.size is None:
            return '(%s)[]' % (self.elementType,)
        return '(%s)[%d]' % (self.elementType, self.size)

    def getTypedefDef(self, name, printer):
        """Return the typedef for this array or vector type, named *name*."""
        elementName = printer.getTypeName(self.elementType)
        if self.isVector:
            return ('typedef %s %s __attribute__ ((vector_size (%d)));'
                    % (elementName, name, self.size))
        sizeStr = '' if self.size is None else str(self.size)
        return 'typedef %s %s[%s];' % (elementName, name, sizeStr)
154
class ComplexType(Type):
    """The C99 ``_Complex`` form of an element type."""

    def __init__(self, index, elementType):
        self.index = index
        self.elementType = elementType

    def __str__(self):
        return '_Complex (%s)' % (self.elementType,)

    def getTypedefDef(self, name, printer):
        """Return ``typedef _Complex <element> <name>;``."""
        elementName = printer.getTypeName(self.elementType)
        return 'typedef _Complex %s %s;' % (elementName, name)
165
class FunctionType(Type):
    """A pointer-to-function type.

    A ``returnType`` of None means ``void``; an empty ``argTypes`` yields
    the prototype ``(void)``.
    """

    def __init__(self, index, returnType, argTypes):
        self.index = index
        self.returnType = returnType
        # Materialize so truthiness and repeated iteration work even when a
        # one-shot iterator (e.g. a Python 3 map object) is passed in.
        self.argTypes = list(argTypes)

    def _signature(self, spell):
        """Return (return-type text, argument-list text).

        Each component type is rendered with the callable *spell*; the
        return type is rendered before the arguments.
        """
        if self.returnType is None:
            rt = 'void'
        else:
            rt = spell(self.returnType)
        if not self.argTypes:
            at = 'void'
        else:
            at = ', '.join([spell(t) for t in self.argTypes])
        return rt, at

    def __str__(self):
        return '%s (*)(%s)' % self._signature(str)

    def getTypedefDef(self, name, printer):
        """Return a typedef declaring *name* as this function-pointer type."""
        # Spell component types through the printer, as RecordType, ArrayType
        # and ComplexType do: str() gives descriptive text such as '(int)[4]'
        # or 'struct { char; }', which is not valid C.
        rt, at = self._signature(printer.getTypeName)
        return 'typedef %s (*%s)(%s);' % (rt, name, at)
193
194###
195# Type enumerators
196
class TypeGenerator(object):
    """Abstract, memoizing enumeration of types indexed by 0 <= N < cardinality.

    Subclasses implement ``setCardinality`` (storing ``self.cardinality``)
    and ``generateType(N)``; ``get`` caches each generated type.
    """

    def __init__(self):
        # Maps index N -> generated type, so each index is built only once.
        self.cache = {}

    def setCardinality(self):
        """Compute and store ``self.cardinality``. Must be overridden."""
        # The original bare name 'abstract' failed with an unrelated NameError.
        raise NotImplementedError(
            '%s must implement setCardinality()' % self.__class__.__name__)

    def get(self, N):
        """Return the N-th type, generating and caching it on first request.

        Raises AssertionError when N lies outside [0, cardinality); callers
        use that to detect the end of the enumeration.
        """
        T = self.cache.get(N)
        if T is None:
            assert 0 <= N < self.cardinality
            T = self.cache[N] = self.generateType(N)
        return T

    def generateType(self, N):
        """Build the N-th type. Must be overridden."""
        raise NotImplementedError(
            '%s must implement generateType()' % self.__class__.__name__)
213
class FixedTypeGenerator(TypeGenerator):
    """Enumerates an explicit, finite sequence of pre-built types.

    Index N simply selects ``types[N]``.
    """

    def __init__(self, types):
        TypeGenerator.__init__(self)
        self.types = types
        self.setCardinality()

    def setCardinality(self):
        # One index per supplied type.
        self.cardinality = len(self.types)

    def generateType(self, N):
        chosen = self.types[N]
        return chosen
225
226# Factorial
def fact(n):
    """Return n! computed by repeated multiplication; 1 for any n <= 0."""
    product = 1
    k = n
    while k > 0:
        product *= k
        k -= 1
    return product
233
234# Compute the number of combinations (n choose k)
def num_combinations(n, k):
    """Return the binomial coefficient C(n, k) as an exact integer.

    Returns 0 when k < 0 or k > n, matching the number of k-element
    combinations that ``combinations`` would yield.
    """
    if k < 0 or k > n:
        return 0
    # C(n, k) == C(n, n - k); iterate over the smaller of the two.
    k = min(k, n - k)
    result = 1
    for i in range(1, k + 1):
        # After this step result == C(n - k + i, i), always an integer, so
        # floor division is exact (true '/' would yield floats on Python 3).
        result = result * (n - k + i) // i
    return result
237
238# Enumerate the combinations choosing k elements from the list of values
def combinations(values, k):
    """Yield every k-element combination of *values*, each as a list.

    Elements keep their order from *values*, and combinations come out in
    lexicographic order of element positions. k == 0 yields one empty
    list; k < 0 or k > len(values) yields nothing. *values* must support
    len(), indexing and slicing.
    """
    # From ActiveState Recipe 190465: Generator for permutations,
    # combinations, selections of a sequence
    if k < 0:
        return
    if k == 0:
        yield []
        return
    # range (not the Python-2-only xrange) keeps this valid on Python 2 and 3.
    for i in range(len(values) - k + 1):
        head = values[i]
        for rest in combinations(values[i + 1:], k - 1):
            yield [head] + rest
247
class EnumTypeGenerator(TypeGenerator):
    """Enumerates EnumType instances built from subsets of *values*.

    Every combination of between ``minEnumerators`` and ``maxEnumerators``
    entries of *values* becomes one enum, whose enumerators use the chosen
    entries as initializers. Indices are grouped by enumerator count,
    smallest count first.
    """

    def __init__(self, values, minEnumerators, maxEnumerators):
        TypeGenerator.__init__(self)
        self.values = values
        self.minEnumerators = minEnumerators
        self.maxEnumerators = maxEnumerators
        self.setCardinality()

    def setCardinality(self):
        counts = range(self.minEnumerators, self.maxEnumerators + 1)
        self.cardinality = sum(num_combinations(len(self.values), c)
                               for c in counts)

    def generateType(self, n):
        numValues = len(self.values)

        # Locate the group (enumerator count) that index n falls into.
        count = self.minEnumerators
        skipped = 0
        while count < self.maxEnumerators:
            groupSize = num_combinations(numValues, count)
            if skipped + groupSize > n:
                break
            skipped += groupSize
            count += 1

        # Walk that group's combinations up to the requested position.
        target = n - skipped
        for position, enumerators in enumerate(combinations(self.values, count)):
            if position == target:
                return EnumType(n, enumerators)

        assert False
282
class ComplexTypeGenerator(TypeGenerator):
    """Wraps each type of *typeGen* in ``_Complex``, one-to-one by index."""

    def __init__(self, typeGen):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.setCardinality()

    def setCardinality(self):
        # Exactly as many complex types as underlying element types.
        self.cardinality = self.typeGen.cardinality

    def generateType(self, N):
        elementType = self.typeGen.get(N)
        return ComplexType(N, elementType)
294
class VectorTypeGenerator(TypeGenerator):
    """Enumerates vector types over every (size, element type) pair.

    *sizes* are vector sizes in bytes (see ArrayType) and are coerced to
    int; element types come from *typeGen*.
    """

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.sizes = tuple([int(s) for s in sizes])
        self.setCardinality()

    def setCardinality(self):
        numSizes = len(self.sizes)
        self.cardinality = numSizes * self.typeGen.cardinality

    def generateType(self, N):
        sizeIndex, typeIndex = getNthPairBounded(N, len(self.sizes),
                                                 self.typeGen.cardinality)
        elementType = self.typeGen.get(typeIndex)
        return ArrayType(N, True, elementType, self.sizes[sizeIndex])
308
class FixedArrayTypeGenerator(TypeGenerator):
    """Enumerates (non-vector) array types over every (size, element) pair.

    *sizes* lists the array lengths to use; element types come from
    *typeGen*.
    """

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        # Was 'tuple(size)': an undefined name that raised NameError.
        self.sizes = tuple(sizes)
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes) * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
        # Was 'false' (undefined in Python); plain arrays are not vectors.
        return ArrayType(N, False, self.typeGen.get(T), self.sizes[S])
322
class ArrayTypeGenerator(TypeGenerator):
    """Enumerates array types with lengths 1..maxSize over *typeGen*'s types.

    ``useIncomplete`` adds the incomplete array ``T[]``; ``useZero`` adds
    the zero-length array ``T[0]``. ``W`` is the number of distinct
    length choices per element type.
    """

    def __init__(self, typeGen, maxSize, useIncomplete=False, useZero=False):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useIncomplete = useIncomplete
        self.useZero = useZero
        self.maxSize = int(maxSize)
        self.W = useIncomplete + useZero + self.maxSize
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = self.W * self.typeGen.cardinality

    def generateType(self, N):
        slot, elemIndex = getNthPairBounded(N, self.W, self.typeGen.cardinality)
        if self.useIncomplete and slot == 0:
            # Slot 0 is reserved for the incomplete array T[].
            size = None
        else:
            if self.useIncomplete:
                slot -= 1
            # With useZero the lengths start at 0, otherwise at 1.
            if self.useZero:
                size = slot
            else:
                size = slot + 1
        return ArrayType(N, False, self.typeGen.get(elemIndex), size)
350
class RecordTypeGenerator(TypeGenerator):
    """Enumerates structs (and unions, if ``useUnion``) with 0..maxSize fields.

    Field types are drawn from *typeGen*. When unions are enabled, the low
    bit of the index chooses struct (0) or union (1).
    """

    def __init__(self, typeGen, useUnion, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useUnion = bool(useUnion)
        # Pass the aleph0 sentinel through untouched: int() would replace a
        # sentinel object (e.g. an int subclass) with a fresh plain int,
        # making the 'is aleph0' test in setCardinality unreachable.
        if maxSize is aleph0:
            self.maxSize = maxSize
        else:
            self.maxSize = int(maxSize)
        self.setCardinality()

    def setCardinality(self):
        # M record kinds: struct only, or struct and union.
        M = 1 + self.useUnion
        if self.maxSize is aleph0:
            S = aleph0 * self.typeGen.cardinality
        else:
            # Sum over field counts i of (kinds) * (choices per field) ** i.
            S = 0
            for i in range(self.maxSize + 1):
                S += M * (self.typeGen.cardinality ** i)
        self.cardinality = S

    def generateType(self, N):
        isUnion, I = False, N
        if self.useUnion:
            isUnion, I = (I & 1), I >> 1
        indices = getNthTuple(I, self.maxSize, self.typeGen.cardinality)
        # A real list, not a lazy map object: RecordType iterates its fields
        # more than once, which a Python 3 map iterator cannot support.
        fields = [self.typeGen.get(i) for i in indices]
        return RecordType(N, isUnion, fields)
375
class FunctionTypeGenerator(TypeGenerator):
    """Enumerates function-pointer types with 0..maxSize arguments.

    Argument (and, with ``useReturn``, return) types come from *typeGen*;
    without ``useReturn`` every function returns void.
    """

    def __init__(self, typeGen, useReturn, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useReturn = useReturn
        self.maxSize = maxSize
        self.setCardinality()

    def setCardinality(self):
        # 'cardinality' is a plain attribute (a property on AnyTypeGenerator);
        # the old 'self.typeGen.cardinality()' call raised TypeError.
        card = self.typeGen.cardinality
        if self.maxSize is aleph0:
            S = aleph0 * card
        elif self.useReturn:
            # A return type plus 0..maxSize args: tuples of length 1..maxSize+1.
            S = 0
            for i in range(1, self.maxSize + 1 + 1):
                S += card ** i
        else:
            # 0..maxSize args with a void return.
            S = 0
            for i in range(self.maxSize + 1):
                S += card ** i
        self.cardinality = S

    def generateType(self, N):
        card = self.typeGen.cardinality
        if self.useReturn:
            # Skip the empty tuple: it has no slot for the return type
            # (setCardinality likewise counts tuples of length >= 1 only).
            indices = getNthTuple(N + 1, self.maxSize + 1, card)
            retIndex, argIndices = indices[0], indices[1:]
            retTy = self.typeGen.get(retIndex)
        else:
            retTy = None
            argIndices = getNthTuple(N, self.maxSize, card)
        # A real list: a Python 3 map object is always truthy and one-shot.
        args = [self.typeGen.get(i) for i in argIndices]
        return FunctionType(N, retTy, args)
408
class AnyTypeGenerator(TypeGenerator):
    """Concatenation of several generators, enumerated one after another.

    ``bounds`` holds each child's cardinality and is used to map an index
    to (child, child index). Adding a generator recomputes all cardinalities
    until they stop changing. While that recomputation is in progress
    ``cardinality`` reads as aleph0, presumably so that children depending
    on this generator see an unbounded size.
    """

    def __init__(self):
        TypeGenerator.__init__(self)
        self.generators = []
        self.bounds = []
        self.setCardinality()
        # None means "not yet known"; getCardinality then reports aleph0.
        self._cardinality = None

    def getCardinality(self):
        if self._cardinality is None:
            return aleph0
        else:
            return self._cardinality

    def setCardinality(self):
        self.bounds = [g.cardinality for g in self.generators]
        self._cardinality = sum(self.bounds)

    cardinality = property(getCardinality, None)

    def addGenerator(self, g):
        """Append generator *g* and recompute cardinalities to a fixed point.

        Raises RuntimeError if no fixed point is reached within 100 rounds.
        """
        self.generators.append(g)
        for _ in range(100):
            prev = self._cardinality
            self._cardinality = None
            # Distinct loop name so the parameter 'g' is not shadowed.
            for child in self.generators:
                child.setCardinality()
            self.setCardinality()
            if (self._cardinality is aleph0) or prev == self._cardinality:
                break
        else:
            # Call syntax is valid on both Python 2 and 3, unlike 'raise E, msg'.
            raise RuntimeError("Infinite loop in setting cardinality")

    def generateType(self, N):
        index, M = getNthPairVariableBounds(N, self.bounds)
        return self.generators[index].get(M)
443
def test():
    """Print up to 100 types from a sample generator setup.

    Also checks that indexing exactly at the reported cardinality is
    rejected (AssertionError from TypeGenerator.get). Otherwise it raises
    RuntimeError, because the cardinality must then be too small.
    """
    fbtg = FixedTypeGenerator([BuiltinType('char', 4),
                               BuiltinType('char', 4, 0),
                               BuiltinType('int',  4, 5)])

    fields1 = AnyTypeGenerator()
    fields1.addGenerator( fbtg )

    fields0 = AnyTypeGenerator()
    fields0.addGenerator( fbtg )
#    fields0.addGenerator( RecordTypeGenerator(fields1, False, 4) )

    btg = FixedTypeGenerator([BuiltinType('char', 4),
                              BuiltinType('int',  4)])
    etg = EnumTypeGenerator([None, '-1', '1', '1u'], 0, 3)

    atg = AnyTypeGenerator()
    atg.addGenerator( btg )
    atg.addGenerator( RecordTypeGenerator(fields0, False, 4) )
    atg.addGenerator( etg )
    # Single pre-formatted argument: prints identically on Python 2 and 3.
    print('Cardinality: %s' % (atg.cardinality,))
    for i in range(100):
        if i == atg.cardinality:
            try:
                atg.get(i)
            except AssertionError:
                break
            raise RuntimeError("Cardinality was wrong")
        print('%4d: %s' % (i, atg.get(i)))

if __name__ == '__main__':
    test()
476