TypeGen.py revision 43ec96da71b94f064ed6a3031025208d23acd592
1"""Flexible enumeration of C types."""
2
3from Enumeration import *
4
5# TODO:
6
7#  - struct improvements (flexible arrays, packed &
8#    unpacked, alignment)
9#  - objective-c qualified id
10#  - anonymous / transparent unions
11#  - VLAs
12#  - block types
13#  - K&R functions
14#  - pass arguments of different types (test extension, transparent union)
15#  - varargs
16
17###
18# Actual type types
19
class Type:
    """Common base class for the C type descriptions in this module.

    It only supplies the bit-field queries, answering "no" to both;
    subclasses that can describe bit-fields override them.
    """

    def isPaddingBitField(self):
        """Return whether this type is a padding bit-field (never, here)."""
        return False

    def isBitField(self):
        """Return whether this type is a bit-field (never, here)."""
        return False
26
class BuiltinType(Type):
    """A named builtin C type, optionally describing a bit-field member.

    name         -- C spelling of the type; also its str() form.
    size         -- value returned by sizeof().
    bitFieldSize -- bit-field width, 0 for a padding bit-field, or None
                    when the type is not a bit-field at all.
    """

    def __init__(self, name, size, bitFieldSize=None):
        self.name = name
        self.size = size
        self.bitFieldSize = bitFieldSize

    def isBitField(self):
        """Return True when a bit-field width (including 0) was given."""
        return self.bitFieldSize is not None

    def isPaddingBitField(self):
        """Return True for a zero-width (padding) bit-field."""
        # Compare by value, not identity: 'is 0' only works when the
        # interpreter happens to share the integer object, and newer
        # Pythons warn about identity tests against literals.
        return self.bitFieldSize == 0

    def getBitFieldSize(self):
        """Return the bit-field width; only valid on bit-field types."""
        assert self.isBitField()
        return self.bitFieldSize

    def sizeof(self):
        """Return the size this type was constructed with."""
        return self.size

    def __str__(self):
        return self.name
48
class EnumType(Type):
    """An enum type described by an index and its enumerator initializers.

    index       -- number embedded in every generated enumerator name.
    enumerators -- one entry per enumerator; a truthy entry is emitted as
                   that enumerator's initializer, a falsy one is omitted.
    """

    def __init__(self, index, enumerators):
        self.index = index
        self.enumerators = enumerators

    def getEnumerators(self):
        """Return the comma-separated enumerator list for the enum body."""
        pieces = []
        for position, init in enumerate(self.enumerators):
            piece = 'enum%dval%d' % (self.index, position)
            if init:
                piece = piece + ' = %s' % (init)
            pieces.append(piece)
        return ', '.join(pieces)

    def __str__(self):
        body = self.getEnumerators()
        return 'enum { %s }' % (body)

    def getTypedefDef(self, name, printer):
        """Return a C typedef declaring this enum under 'name'.

        'printer' is accepted for interface parity with the other types
        and is not used.
        """
        body = self.getEnumerators()
        return 'typedef enum %s { %s } %s;'%(name, body, name)
70
class RecordType(Type):
    """A struct or union made of an ordered sequence of field types.

    index   -- number identifying this type within its generator.
    isUnion -- truthy for a union, falsy for a struct; it is also used
               directly as an index into ('struct', 'union').
    fields  -- the field types; bit-field members answer isBitField().
    """

    def __init__(self, index, isUnion, fields):
        self.index = index
        self.isUnion = isUnion
        self.fields = fields
        self.name = None

    def __str__(self):
        def getField(t):
            if t.isBitField():
                return "%s : %d;" % (t, t.getBitFieldSize())
            else:
                return "%s;" % t

        return '%s { %s }'%(('struct','union')[self.isUnion],
                            ' '.join([getField(t) for t in self.fields]))

    def getTypedefDef(self, name, printer):
        """Return a C typedef defining this record under 'name'.

        Field type names come from printer.getTypeName(); fields are named
        field<i> by position, except padding bit-fields, which must stay
        unnamed.
        """
        # Takes two plain arguments instead of a tuple parameter:
        # tuple-parameter unpacking in a 'def' is Python 2-only syntax.
        def getField(i, t):
            typeName = printer.getTypeName(t)
            if not t.isBitField():
                return '%s field%d;'%(typeName,i)
            if t.isPaddingBitField():
                return '%s : 0;'%(typeName,)
            return '%s field%d : %d;'%(typeName,i,
                                       t.getBitFieldSize())

        fields = [getField(i, t) for i, t in enumerate(self.fields)]
        # Name the struct for more readable LLVM IR.
        return 'typedef %s %s { %s } %s;'%(('struct','union')[self.isUnion],
                                           name, ' '.join(fields), name)
102
class ArrayType(Type):
    """An array or vector of a single element type.

    index       -- number identifying this type within its generator.
    isVector    -- True for a vector_size vector, False for a C array.
    elementType -- the element type.
    size        -- for vectors, the total size in bytes (must be a
                   positive multiple of the element size); for arrays,
                   the element count, or None for an incomplete array.
    """

    def __init__(self, index, isVector, elementType, size):
        # Validate the size before storing anything on the instance.
        if not isVector:
            assert size is None or size >= 0
        else:
            # Note that for vectors, this is the size in bytes.
            assert size > 0

        self.index = index
        self.isVector = isVector
        self.elementType = elementType
        self.size = size

        if not isVector:
            self.numElements = self.size
            return

        eltSize = self.elementType.sizeof()
        assert not (self.size % eltSize)
        self.numElements = self.size // eltSize

    def __str__(self):
        if self.isVector:
            return 'vector (%s)[%d]'%(self.elementType,self.size)
        if self.size is None:
            return '(%s)[]'%(self.elementType,)
        return '(%s)[%d]'%(self.elementType,self.size)

    def getTypedefDef(self, name, printer):
        """Return a C typedef declaring this array or vector as 'name'."""
        elementName = printer.getTypeName(self.elementType)
        if not self.isVector:
            # An incomplete array is spelled with empty brackets.
            sizeStr = '' if self.size is None else str(self.size)
            return 'typedef %s %s[%s];'%(elementName, name, sizeStr)
        return 'typedef %s %s __attribute__ ((vector_size (%d)));'%(elementName,
                                                                    name,
                                                                    self.size)
141
class ComplexType(Type):
    """A _Complex type built on top of an element type.

    index       -- number identifying this type within its generator.
    elementType -- the underlying element type.
    """

    def __init__(self, index, elementType):
        self.index = index
        self.elementType = elementType

    def __str__(self):
        element = self.elementType
        return '_Complex (%s)'%(element)

    def getTypedefDef(self, name, printer):
        """Return a C typedef declaring this complex type as 'name'."""
        elementName = printer.getTypeName(self.elementType)
        return 'typedef _Complex %s %s;'%(elementName, name)
152
class FunctionType(Type):
    """A pointer-to-function type.

    index      -- number identifying this type within its generator.
    returnType -- the return type, or None for a void function.
    argTypes   -- sequence of argument types; empty means '(void)'.
    """

    def __init__(self, index, returnType, argTypes):
        self.index = index
        self.returnType = returnType
        self.argTypes = argTypes

    def __str__(self):
        if self.returnType is None:
            rt = 'void'
        else:
            rt = str(self.returnType)
        if not self.argTypes:
            at = 'void'
        else:
            at = ', '.join(map(str, self.argTypes))
        return '%s (*)(%s)'%(rt, at)

    def getTypedefDef(self, name, printer):
        """Return a C typedef declaring this function pointer as 'name'.

        The return and argument types are spelled through
        printer.getTypeName(), as the other types' getTypedefDef methods
        do. The previous code ignored 'printer' and used str(), which for
        composite types yields the descriptive form (for example an
        anonymous 'struct { ... }' with unnamed fields) rather than a
        declared type name.
        """
        if self.returnType is None:
            rt = 'void'
        else:
            rt = printer.getTypeName(self.returnType)
        if not self.argTypes:
            at = 'void'
        else:
            at = ', '.join([printer.getTypeName(t) for t in self.argTypes])
        return 'typedef %s (*%s)(%s);'%(rt, name, at)
180
181###
182# Type enumerators
183
class TypeGenerator(object):
    """Base class for indexable, memoised families of types.

    Subclasses must provide a 'cardinality' attribute (normally set by
    setCardinality) and implement generateType(N) for 0 <= N < cardinality.
    """

    def __init__(self):
        # Maps an index N to the type already generated for it.
        self.cache = {}

    def setCardinality(self):
        """Compute and store self.cardinality; subclasses must override."""
        # Previously the bare name 'abstract', which only failed through an
        # incidental NameError; say what is actually wrong instead.
        raise NotImplementedError('%s must implement setCardinality'
                                  % (self.__class__.__name__,))

    def get(self, N):
        """Return the N-th type, generating and caching it on first use.

        Raises AssertionError when N is outside [0, cardinality) and not
        already cached.
        """
        T = self.cache.get(N)
        if T is None:
            # Kept as an assert: the test driver in this file probes the
            # end of the range by catching AssertionError.
            assert 0 <= N < self.cardinality
            T = self.cache[N] = self.generateType(N)
        return T

    def generateType(self, N):
        """Build the N-th type; subclasses must override."""
        raise NotImplementedError('%s must implement generateType'
                                  % (self.__class__.__name__,))
200
class FixedTypeGenerator(TypeGenerator):
    """Generator over an explicit, pre-built sequence of types."""

    def __init__(self, types):
        super(FixedTypeGenerator, self).__init__()
        self.types = types
        self.setCardinality()

    def setCardinality(self):
        """The cardinality is simply the number of supplied types."""
        self.cardinality = len(self.types)

    def generateType(self, N):
        """Return the N-th supplied type unchanged."""
        return self.types[N]
212
213# Factorial
def fact(n):
    """Return n * (n-1) * ... * 1, or 1 when n is not positive."""
    product = 1
    remaining = n
    while remaining > 0:
        product *= remaining
        remaining -= 1
    return product
220
221# Compute the number of combinations (n choose k)
def num_combinations(n, k):
    """Return the binomial coefficient "n choose k" as an exact integer.

    Returns 0 when k lies outside 0..n. The result is always an integer:
    the previous 'fact(n) / (...)' form yields a float under true
    division (Python 3), which then leaks into cardinalities and index
    arithmetic.
    """
    if k < 0 or k > n:
        return 0
    # C(n, k) == C(n, n - k); iterate over the smaller of the two.
    k = min(k, n - k)
    result = 1
    for i in range(1, k + 1):
        # After this step result == C(n - k + i, i), so the floor
        # division is always exact.
        result = result * (n - k + i) // i
    return result
224
225# Enumerate the combinations choosing k elements from the list of values
def combinations(values, k):
    """Yield every k-element combination of 'values' as a new list.

    Elements keep their order from 'values'. k == 0 yields a single empty
    list, and nothing is yielded when k exceeds len(values).
    """
    # From ActiveState Recipe 190465: Generator for permutations,
    # combinations, selections of a sequence
    if k == 0:
        yield []
    else:
        # range() rather than xrange(): xrange does not exist on Python 3,
        # and the index ranges used here are small.
        for i in range(len(values) - k + 1):
            for cc in combinations(values[i+1:], k - 1):
                yield [values[i]] + cc
234
class EnumTypeGenerator(TypeGenerator):
    """Generator of enum types over combinations of initializer values.

    values         -- candidate enumerator initializers.
    minEnumerators -- smallest number of enumerators in a generated enum.
    maxEnumerators -- largest number of enumerators in a generated enum.
    """

    def __init__(self, values, minEnumerators, maxEnumerators):
        TypeGenerator.__init__(self)
        self.values = values
        self.minEnumerators = minEnumerators
        self.maxEnumerators = maxEnumerators
        self.setCardinality()

    def setCardinality(self):
        """Total the number of combinations over every allowed size."""
        pool = len(self.values)
        self.cardinality = sum(
            num_combinations(pool, num)
            for num in range(self.minEnumerators, self.maxEnumerators + 1))

    def generateType(self, n):
        """Return the n-th enum, ordered by enumerator count first."""
        # Walk the size classes until the one containing index n is found;
        # 'skipped' counts the types belonging to the smaller classes.
        count = self.minEnumerators
        skipped = 0
        while count < self.maxEnumerators:
            comb = num_combinations(len(self.values), count)
            if skipped + comb > n:
                break
            count += 1
            skipped += comb

        # Pick the requested combination within that size class and build
        # a type from it.
        wanted = n - skipped
        for position, enumerators in enumerate(combinations(self.values,
                                                            count)):
            if position == wanted:
                return EnumType(n, enumerators)

        assert False
269
class ComplexTypeGenerator(TypeGenerator):
    """Generator wrapping each type of 'typeGen' in a ComplexType."""

    def __init__(self, typeGen):
        super(ComplexTypeGenerator, self).__init__()
        self.typeGen = typeGen
        self.setCardinality()

    def setCardinality(self):
        """One complex type exists per element type."""
        self.cardinality = self.typeGen.cardinality

    def generateType(self, N):
        """Return the complex type over the N-th element type."""
        elementType = self.typeGen.get(N)
        return ComplexType(N, elementType)
281
class VectorTypeGenerator(TypeGenerator):
    """Generator of vector types: each size paired with each element type.

    typeGen -- generator supplying the element types.
    sizes   -- vector sizes; each is converted with int().
    """

    def __init__(self, typeGen, sizes):
        super(VectorTypeGenerator, self).__init__()
        self.typeGen = typeGen
        self.sizes = tuple([int(s) for s in sizes])
        self.setCardinality()

    def setCardinality(self):
        """One vector type per (size, element type) pair."""
        self.cardinality = len(self.sizes)*self.typeGen.cardinality

    def generateType(self, N):
        """Return the N-th vector type."""
        # getNthPairBounded comes from Enumeration; presumably it decodes N
        # into a (size index, element index) pair within the given bounds.
        sizeIndex, typeIndex = getNthPairBounded(N, len(self.sizes),
                                                 self.typeGen.cardinality)
        elementType = self.typeGen.get(typeIndex)
        return ArrayType(N, True, elementType, self.sizes[sizeIndex])
295
class FixedArrayTypeGenerator(TypeGenerator):
    """Generator of array types drawn from an explicit list of sizes.

    typeGen -- generator supplying the element types.
    sizes   -- the array sizes to pair with every element type.
    """

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        # Was tuple(size): an undefined name that raised NameError as soon
        # as the class was instantiated.
        self.sizes = tuple(sizes)
        self.setCardinality()

    def setCardinality(self):
        """One array type per (size, element type) pair."""
        self.cardinality = len(self.sizes)*self.typeGen.cardinality

    def generateType(self, N):
        """Return the N-th (non-vector) array type."""
        S,T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
        # Was the undefined name 'false'; arrays are the non-vector case.
        return ArrayType(N, False, self.typeGen.get(T), self.sizes[S])
309
class ArrayTypeGenerator(TypeGenerator):
    """Generator of array types with sizes up to maxSize.

    typeGen       -- generator supplying the element types.
    maxSize       -- upper bound on the size range (converted with int()).
    useIncomplete -- also generate the incomplete array (size None).
    useZero       -- also generate the zero-length array.
    """

    def __init__(self, typeGen, maxSize, useIncomplete=False, useZero=False):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useIncomplete = useIncomplete
        self.useZero = useZero
        self.maxSize = int(maxSize)
        # Number of distinct size choices: the optional incomplete and zero
        # sizes plus the sizes 1..maxSize.
        self.W = useIncomplete + useZero + self.maxSize
        self.setCardinality()

    def setCardinality(self):
        """One array type per (size choice, element type) pair."""
        self.cardinality = self.W * self.typeGen.cardinality

    def generateType(self, N):
        """Return the N-th array type."""
        S,T = getNthPairBounded(N, self.W, self.typeGen.cardinality)
        if self.useIncomplete and S == 0:
            # Slot 0 is reserved for the incomplete array.
            size = None
        else:
            if self.useIncomplete:
                # Shift past the slot taken by the incomplete array.
                S = S - 1
            # Without the zero-length option the sizes start at 1.
            size = S if self.useZero else S + 1
        return ArrayType(N, False, self.typeGen.get(T), size)
337
class RecordTypeGenerator(TypeGenerator):
    """Generator of struct (and optionally union) types.

    typeGen  -- generator supplying the field types.
    useUnion -- when true, every field tuple is generated both as a
                struct and as a union.
    maxSize  -- maximum number of fields in a record.
    """

    def __init__(self, typeGen, useUnion, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useUnion = bool(useUnion)
        self.maxSize = int(maxSize)
        self.setCardinality()

    def setCardinality(self):
        """Count the field tuples of length 0..maxSize, per record kind."""
        # M is the number of record kinds: struct only, or struct and union.
        M = 1 + self.useUnion
        # NOTE(review): maxSize went through int() in __init__; confirm that
        # aleph0 survives that conversion, otherwise this identity test can
        # never succeed.
        if self.maxSize is aleph0:
            S =  aleph0 * self.typeGen.cardinality
        else:
            S = 0
            for i in range(self.maxSize+1):
                S += M * (self.typeGen.cardinality ** i)
        self.cardinality = S

    def generateType(self, N):
        """Return the N-th record type."""
        isUnion,I = False,N
        if self.useUnion:
            # The low bit selects struct/union, the rest the field tuple.
            isUnion,I = (I&1),I>>1
        # Build a real list: RecordType iterates its fields every time it
        # is printed, and a Python 3 map() object can only be consumed once.
        fields = [self.typeGen.get(index)
                  for index in getNthTuple(I,self.maxSize,self.typeGen.cardinality)]
        return RecordType(N, isUnion, fields)
362
class FunctionTypeGenerator(TypeGenerator):
    """Generator of function pointer types.

    typeGen   -- generator supplying the return and argument types.
    useReturn -- when true, functions get a return type taken from
                 typeGen; otherwise they return void.
    maxSize   -- maximum number of arguments.
    """

    def __init__(self, typeGen, useReturn, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useReturn = useReturn
        self.maxSize = maxSize
        self.setCardinality()

    def setCardinality(self):
        """Count the type tuples that can describe a function."""
        if self.maxSize is aleph0:
            # 'cardinality' is an attribute, not a method; calling it
            # raised TypeError on this branch.
            S = aleph0 * self.typeGen.cardinality
        elif self.useReturn:
            # Tuples of length 1..maxSize+1: the return type plus up to
            # maxSize arguments.
            S = 0
            for i in range(1,self.maxSize+1+1):
                S += self.typeGen.cardinality ** i
        else:
            S = 0
            for i in range(self.maxSize+1):
                S += self.typeGen.cardinality ** i
        self.cardinality = S

    def generateType(self, N):
        """Return the N-th function type."""
        if self.useReturn:
            # Skip the empty tuple
            argIndices = getNthTuple(N+1, self.maxSize+1, self.typeGen.cardinality)
            retIndex,argIndices = argIndices[0],argIndices[1:]
            retTy = self.typeGen.get(retIndex)
        else:
            retTy = None
            argIndices = getNthTuple(N, self.maxSize, self.typeGen.cardinality)
        # Build a real list: FunctionType tests 'not self.argTypes' to emit
        # '(void)', and a Python 3 map() object is truthy even when empty.
        args = [self.typeGen.get(index) for index in argIndices]
        return FunctionType(N, retTy, args)
395
class AnyTypeGenerator(TypeGenerator):
    """Generator that is the union of a growing list of other generators.

    While its own cardinality is being recomputed, the 'cardinality'
    property reports aleph0.
    """

    def __init__(self):
        TypeGenerator.__init__(self)
        self.generators = []
        self.bounds = []
        self.setCardinality()
        # Starts as "unknown", so 'cardinality' reads as aleph0 until
        # addGenerator() has computed a value.
        self._cardinality = None

    def getCardinality(self):
        """Return the computed cardinality, or aleph0 while it is unknown."""
        if self._cardinality is None:
            return aleph0
        else:
            return self._cardinality

    def setCardinality(self):
        """Recompute the per-generator bounds and their sum."""
        self.bounds = [g.cardinality for g in self.generators]
        self._cardinality = sum(self.bounds)

    cardinality = property(getCardinality, None)

    def addGenerator(self, g):
        """Add generator 'g' and recompute cardinalities until they settle.

        Raises RuntimeError if the value has not stabilised after 100
        rounds.
        """
        self.generators.append(g)
        for i in range(100):
            prev = self._cardinality
            # Mark our own cardinality as unknown while the members
            # recompute theirs - presumably so that generators referring
            # back to this one see aleph0 rather than a stale value.
            self._cardinality = None
            # A distinct loop name, so the parameter 'g' is not rebound.
            for member in self.generators:
                member.setCardinality()
            self.setCardinality()
            if (self._cardinality is aleph0) or prev==self._cardinality:
                break
        else:
            # Call form of raise: the 'raise E, "msg"' statement is
            # Python 2-only syntax.
            raise RuntimeError("Infinite loop in setting cardinality")

    def generateType(self, N):
        """Return the N-th type by dispatching to the owning generator."""
        index,M = getNthPairVariableBounds(N, self.bounds)
        return self.generators[index].get(M)
430
def test():
    """Smoke test: print the first types of a small combined generator.

    Prints the combined cardinality, then up to 100 generated types. On
    reaching the index equal to the cardinality it checks that get()
    rejects that index with an AssertionError.
    """
    fbtg = FixedTypeGenerator([BuiltinType('char', 4),
                               BuiltinType('char', 4, 0),
                               BuiltinType('int',  4, 5)])

    fields1 = AnyTypeGenerator()
    fields1.addGenerator( fbtg )

    fields0 = AnyTypeGenerator()
    fields0.addGenerator( fbtg )
#    fields0.addGenerator( RecordTypeGenerator(fields1, False, 4) )

    btg = FixedTypeGenerator([BuiltinType('char', 4),
                              BuiltinType('int',  4)])
    etg = EnumTypeGenerator([None, '-1', '1', '1u'], 0, 3)

    atg = AnyTypeGenerator()
    atg.addGenerator( btg )
    atg.addGenerator( RecordTypeGenerator(fields0, False, 4) )
    atg.addGenerator( etg )
    # Single pre-formatted strings print identically whether 'print' is
    # the Python 2 statement or the Python 3 function.
    print('Cardinality: %s' % (atg.cardinality,))
    for i in range(100):
        if i == atg.cardinality:
            try:
                atg.get(i)
                raise RuntimeError("Cardinality was wrong")
            except AssertionError:
                break
        print('%4d: %s'%(i, atg.get(i)))
460
if __name__ == '__main__':
    # Run the smoke test only when executed as a script, not on import.
    test()
463