slang_rs_export_type.cpp revision a41ce1d98094da84643995d40d71c529905123fc
1#include "slang_rs_export_type.h"
2
3#include <vector>
4
5#include "llvm/Type.h"
6#include "llvm/DerivedTypes.h"
7
8#include "llvm/ADT/StringExtras.h"
9#include "llvm/Target/TargetData.h"
10
11#include "clang/AST/RecordLayout.h"
12
13#include "slang_rs_context.h"
14#include "slang_rs_export_element.h"
15
16using namespace slang;
17
18/****************************** RSExportType ******************************/
19bool RSExportType::NormalizeType(const clang::Type *&T,
20                                 llvm::StringRef &TypeName) {
21  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
22      llvm::SmallPtrSet<const clang::Type*, 8>();
23
24  if ((T = RSExportType::TypeExportable(T, SPS)) == NULL)
25    // TODO(zonr): warn that type not exportable.
26    return false;
27
28  // Get type name
29  TypeName = RSExportType::GetTypeName(T);
30  if (TypeName.empty())
31    // TODO(zonr): warning that the type is unnamed.
32    return false;
33
34  return true;
35}
36
37const clang::Type
38*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
39  if (DD) {
40    clang::QualType T;
41    if (DD->getTypeSourceInfo())
42      T = DD->getTypeSourceInfo()->getType();
43    else
44      T = DD->getType();
45
46    if (T.isNull())
47      return NULL;
48    else
49      return T.getTypePtr();
50  }
51  return NULL;
52}
53
54llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
55  T = GET_CANONICAL_TYPE(T);
56  if (T == NULL)
57    return llvm::StringRef();
58
59  switch (T->getTypeClass()) {
60    case clang::Type::Builtin: {
61      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
62
63      switch (BT->getKind()) {
64        // Compiler is smart enough to optimize following *big if branches*
65        // since they all become "constant comparison" after macro expansion
66#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
67        case builtin_type: {                                    \
68          if (type == RSExportPrimitiveType::DataTypeFloat32)           \
69            return "float";                                             \
70          else if (type == RSExportPrimitiveType::DataTypeFloat64)      \
71            return "double";                                            \
72          else if (type == RSExportPrimitiveType::DataTypeUnsigned8)    \
73            return "uchar";                                             \
74          else if (type == RSExportPrimitiveType::DataTypeUnsigned16)   \
75            return "ushort";                                            \
76          else if (type == RSExportPrimitiveType::DataTypeUnsigned32)   \
77            return "uint";                                              \
78          else if (type == RSExportPrimitiveType::DataTypeSigned8)      \
79            return "char";                                              \
80          else if (type == RSExportPrimitiveType::DataTypeSigned16)     \
81            return "short";                                             \
82          else if (type == RSExportPrimitiveType::DataTypeSigned32)     \
83            return "int";                                               \
84          else if (type == RSExportPrimitiveType::DataTypeSigned64)     \
85            return "long";                                              \
86          else if (type == RSExportPrimitiveType::DataTypeBoolean)      \
87            return "bool";                                              \
88          else                                                          \
89            assert(false && "Unknow data type of supported builtin");   \
90          break;                                                        \
91        }
92#include "slang_rs_export_type_support.inc"
93
94          default: {
95            assert(false && "Unknown data type of the builtin");
96            break;
97          }
98        }
99      break;
100    }
101    case clang::Type::Record: {
102      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
103      llvm::StringRef Name = RD->getName();
104      if (Name.empty()) {
105          if (RD->getTypedefForAnonDecl() != NULL)
106            Name = RD->getTypedefForAnonDecl()->getName();
107
108          if (Name.empty())
109            // Try to find a name from redeclaration (i.e. typedef)
110            for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
111                     RE = RD->redecls_end();
112                 RI != RE;
113                 RI++) {
114              assert(*RI != NULL && "cannot be NULL object");
115
116              Name = (*RI)->getName();
117              if (!Name.empty())
118                break;
119            }
120      }
121      return Name;
122    }
123    case clang::Type::Pointer: {
124      // "*" plus pointee name
125      const clang::Type *PT = GET_POINTEE_TYPE(T);
126      llvm::StringRef PointeeName;
127      if (NormalizeType(PT, PointeeName)) {
128        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
129        Name[0] = '*';
130        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
131        Name[PointeeName.size() + 1] = '\0';
132        return Name;
133      }
134      break;
135    }
136    case clang::Type::ExtVector: {
137      const clang::ExtVectorType *EVT =
138          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
139      return RSExportVectorType::GetTypeName(EVT);
140      break;
141    }
142    case clang::Type::ConstantArray : {
143      // Construct name for a constant array is too complicated.
144      return DUMMY_TYPE_NAME_FOR_RS_CONSTANT_ARRAY_TYPE;
145    }
146    default: {
147      break;
148    }
149  }
150
151  return llvm::StringRef();
152}
153
154const clang::Type *RSExportType::TypeExportable(
155    const clang::Type *T,
156    llvm::SmallPtrSet<const clang::Type*, 8>& SPS) {
157  // Normalize first
158  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
159    return NULL;
160
161  if (SPS.count(T))
162    return T;
163
164  switch (T->getTypeClass()) {
165    case clang::Type::Builtin: {
166      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
167
168      switch (BT->getKind()) {
169#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
170        case builtin_type:
171#include "slang_rs_export_type_support.inc"
172        {
173          return T;
174        }
175        default: {
176          return NULL;
177        }
178      }
179      // Never be here
180    }
181    case clang::Type::Record: {
182      if (RSExportPrimitiveType::GetRSObjectType(T) !=
183          RSExportPrimitiveType::DataTypeUnknown)
184        return T;  // RS object type, no further checks are needed
185
186      // Check internal struct
187      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
188      if (RD != NULL)
189        RD = RD->getDefinition();
190
191      // Fast check
192      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
193        return NULL;
194
195      // Insert myself into checking set
196      SPS.insert(T);
197
198      // Check all element
199      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
200               FE = RD->field_end();
201           FI != FE;
202           FI++) {
203        const clang::FieldDecl *FD = *FI;
204        const clang::Type *FT = GetTypeOfDecl(FD);
205        FT = GET_CANONICAL_TYPE(FT);
206
207        if (!TypeExportable(FT, SPS)) {
208          fprintf(stderr, "Field `%s' in Record `%s' contains unsupported "
209                          "type\n", FD->getNameAsString().c_str(),
210                                    RD->getNameAsString().c_str());
211          FT->dump();
212          return NULL;
213        }
214      }
215
216      return T;
217    }
218    case clang::Type::Pointer: {
219      const clang::PointerType *PT = UNSAFE_CAST_TYPE(clang::PointerType, T);
220      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
221
222      if (PointeeType->getTypeClass() == clang::Type::Pointer)
223        return T;
224      // We don't support pointer with array-type pointee or unsupported pointee
225      // type
226      if (PointeeType->isArrayType() ||
227         (TypeExportable(PointeeType, SPS) == NULL) )
228        return NULL;
229      else
230        return T;
231    }
232    case clang::Type::ExtVector: {
233      const clang::ExtVectorType *EVT =
234          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
235      // Only vector with size 2, 3 and 4 are supported.
236      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
237        return NULL;
238
239      // Check base element type
240      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
241
242      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
243          (TypeExportable(ElementType, SPS) == NULL))
244        return NULL;
245      else
246        return T;
247    }
248    case clang::Type::ConstantArray: {
249      const clang::ConstantArrayType *CAT =
250          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);
251
252      // Check size
253      if (CAT->getSize().getActiveBits() > 32) {
254        fprintf(stderr, "RSExportConstantArrayType::Create : array with too "
255                        "large size (> 2^32).\n");
256        return NULL;
257      }
258      // Check element type
259      const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
260      if (ElementType->isArrayType()) {
261        fprintf(stderr, "RSExportType::TypeExportable : constant array with 2 "
262                        "or higher dimension of constant is not supported.\n");
263        return NULL;
264      }
265      if (TypeExportable(ElementType, SPS) == NULL)
266        return NULL;
267      else
268        return T;
269    }
270    default: {
271      return NULL;
272    }
273  }
274}
275
276RSExportType *RSExportType::Create(RSContext *Context,
277                                   const clang::Type *T,
278                                   const llvm::StringRef &TypeName) {
279  // Lookup the context to see whether the type was processed before.
280  // Newly created RSExportType will insert into context
281  // in RSExportType::RSExportType()
282  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
283
284  if (ETI != Context->export_types_end())
285    return ETI->second;
286
287  RSExportType *ET = NULL;
288  switch (T->getTypeClass()) {
289    case clang::Type::Record: {
290      RSExportPrimitiveType::DataType dt =
291          RSExportPrimitiveType::GetRSObjectType(TypeName);
292      switch (dt) {
293        case RSExportPrimitiveType::DataTypeUnknown: {
294          // User-defined types
295          ET = RSExportRecordType::Create(Context,
296                                          T->getAsStructureType(),
297                                          TypeName);
298          break;
299        }
300        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
301          // 2 x 2 Matrix type
302          ET = RSExportMatrixType::Create(Context,
303                                          T->getAsStructureType(),
304                                          TypeName,
305                                          2);
306          break;
307        }
308        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
309          // 3 x 3 Matrix type
310          ET = RSExportMatrixType::Create(Context,
311                                          T->getAsStructureType(),
312                                          TypeName,
313                                          3);
314          break;
315        }
316        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
317          // 4 x 4 Matrix type
318          ET = RSExportMatrixType::Create(Context,
319                                          T->getAsStructureType(),
320                                          TypeName,
321                                          4);
322          break;
323        }
324        default: {
325          // Others are primitive types
326          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
327          break;
328        }
329      }
330      break;
331    }
332    case clang::Type::Builtin: {
333      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
334      break;
335    }
336    case clang::Type::Pointer: {
337      ET = RSExportPointerType::Create(Context,
338                                       UNSAFE_CAST_TYPE(clang::PointerType, T),
339                                       TypeName);
340      // FIXME: free the name (allocated in RSExportType::GetTypeName)
341      delete [] TypeName.data();
342      break;
343    }
344    case clang::Type::ExtVector: {
345      ET = RSExportVectorType::Create(Context,
346                                      UNSAFE_CAST_TYPE(clang::ExtVectorType, T),
347                                      TypeName);
348      break;
349    }
350    case clang::Type::ConstantArray: {
351      ET = RSExportConstantArrayType::Create(
352              Context,
353              UNSAFE_CAST_TYPE(clang::ConstantArrayType, T));
354      break;
355    }
356    default: {
357      // TODO(zonr): warn that type is not exportable.
358      fprintf(stderr,
359              "RSExportType::Create : type '%s' is not exportable\n",
360              T->getTypeClassName());
361      break;
362    }
363  }
364
365  return ET;
366}
367
368RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
369  llvm::StringRef TypeName;
370  if (NormalizeType(T, TypeName))
371    return Create(Context, T, TypeName);
372  else
373    return NULL;
374}
375
376RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
377                                           const clang::VarDecl *VD) {
378  return RSExportType::Create(Context, GetTypeOfDecl(VD));
379}
380
381size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
382  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
383      ET->getLLVMType());
384}
385
386size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
387  if (ET->getClass() == RSExportType::ExportClassRecord)
388    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
389  else
390    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
391        ET->getLLVMType());
392}
393
394RSExportType::RSExportType(RSContext *Context,
395                           ExportClass Class,
396                           const llvm::StringRef &Name)
397    : RSExportable(Context, RSExportable::EX_TYPE),
398      mContext(Context),
399      mClass(Class),
400      // Make a copy on Name since memory stored @Name is either allocated in
401      // ASTContext or allocated in GetTypeName which will be destroyed later.
402      mName(Name.data(), Name.size()),
403      mLLVMType(NULL) {
404  // Don't cache the type whose name start with '<'. Those type failed to
405  // get their name since constructing their name in GetTypeName() requiring
406  // complicated work.
407  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
408    // TODO(zonr): Need to check whether the insertion is successful or not.
409    Context->insertExportType(llvm::StringRef(Name), this);
410  return;
411}
412
413/************************** RSExportPrimitiveType **************************/
414llvm::ManagedStatic<RSExportPrimitiveType::RSObjectTypeMapTy>
415RSExportPrimitiveType::RSObjectTypeMap;
416
417llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
418
419bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
420  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
421    return true;
422  else
423    return false;
424}
425
426RSExportPrimitiveType::DataType
427RSExportPrimitiveType::GetRSObjectType(const llvm::StringRef &TypeName) {
428  if (TypeName.empty())
429    return DataTypeUnknown;
430
431  if (RSObjectTypeMap->empty()) {
432#define USE_ELEMENT_DATA_TYPE
433#define DEF_RS_OBJECT_TYPE(type, name)                                  \
434    RSObjectTypeMap->GetOrCreateValue(name, GET_ELEMENT_DATA_TYPE(type));
435#include "slang_rs_export_element_support.inc"
436  }
437
438  RSObjectTypeMapTy::const_iterator I = RSObjectTypeMap->find(TypeName);
439  if (I == RSObjectTypeMap->end())
440    return DataTypeUnknown;
441  else
442    return I->getValue();
443}
444
445RSExportPrimitiveType::DataType
446RSExportPrimitiveType::GetRSObjectType(const clang::Type *T) {
447  T = GET_CANONICAL_TYPE(T);
448  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
449    return DataTypeUnknown;
450
451  return GetRSObjectType( RSExportType::GetTypeName(T) );
452}
453
454const size_t
455RSExportPrimitiveType::SizeOfDataTypeInBits[
456    RSExportPrimitiveType::DataTypeMax + 1] = {
457  16,  // DataTypeFloat16
458  32,  // DataTypeFloat32
459  64,  // DataTypeFloat64
460  8,   // DataTypeSigned8
461  16,  // DataTypeSigned16
462  32,  // DataTypeSigned32
463  64,  // DataTypeSigned64
464  8,   // DataTypeUnsigned8
465  16,  // DataTypeUnsigned16
466  32,  // DataTypeUnsigned32
467  64,  // DataTypeUnSigned64
468  1,   // DataTypeBoolean
469
470  16,  // DataTypeUnsigned565
471  16,  // DataTypeUnsigned5551
472  16,  // DataTypeUnsigned4444
473
474  128,  // DataTypeRSMatrix2x2
475  288,  // DataTypeRSMatrix3x3
476  512,  // DataTypeRSMatrix4x4
477
478  32,  // DataTypeRSElement
479  32,  // DataTypeRSType
480  32,  // DataTypeRSAllocation
481  32,  // DataTypeRSSampler
482  32,  // DataTypeRSScript
483  32,  // DataTypeRSMesh
484  32,  // DataTypeRSProgramFragment
485  32,  // DataTypeRSProgramVertex
486  32,  // DataTypeRSProgramRaster
487  32,  // DataTypeRSProgramStore
488  32,  // DataTypeRSFont
489  0
490};
491
492size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
493  assert(((EPT->getType() >= DataTypeFloat32) &&
494          (EPT->getType() < DataTypeMax)) &&
495         "RSExportPrimitiveType::GetSizeInBits : unknown data type");
496  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
497}
498
499RSExportPrimitiveType::DataType
500RSExportPrimitiveType::GetDataType(const clang::Type *T) {
501  if (T == NULL)
502    return DataTypeUnknown;
503
504  switch (T->getTypeClass()) {
505    case clang::Type::Builtin: {
506      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
507      switch (BT->getKind()) {
508#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
509        case builtin_type: {                                    \
510          return type;                                          \
511          break;                                                \
512        }
513#include "slang_rs_export_type_support.inc"
514
515        // The size of types Long, ULong and WChar depend on platform so we
516        // abandon the support to them. Type of its size exceeds 32 bits (e.g.
517        // int64_t, double, etc.): no support
518
519        default: {
520          // TODO(zonr): warn that the type is unsupported
521          fprintf(stderr, "RSExportPrimitiveType::GetDataType : built-in type "
522                          "has no corresponding data type for built-in type");
523          break;
524        }
525      }
526      break;
527    }
528
529    case clang::Type::Record: {
530      // must be RS object type
531      return RSExportPrimitiveType::GetRSObjectType(T);
532      break;
533    }
534
535    default: {
536      fprintf(stderr, "RSExportPrimitiveType::GetDataType : type '%s' is not "
537                      "supported primitive type", T->getTypeClassName());
538      break;
539    }
540  }
541
542  return DataTypeUnknown;
543}
544
545RSExportPrimitiveType
546*RSExportPrimitiveType::Create(RSContext *Context,
547                               const clang::Type *T,
548                               const llvm::StringRef &TypeName,
549                               DataKind DK,
550                               bool Normalized) {
551  DataType DT = GetDataType(T);
552
553  if ((DT == DataTypeUnknown) || TypeName.empty())
554    return NULL;
555  else
556    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
557                                     DT, DK, Normalized);
558}
559
560RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
561                                                     const clang::Type *T,
562                                                     DataKind DK) {
563  llvm::StringRef TypeName;
564  if (RSExportType::NormalizeType(T, TypeName) && IsPrimitiveType(T))
565    return Create(Context, T, TypeName, DK);
566  else
567    return NULL;
568}
569
570const llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
571  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
572
573  if (isRSObjectType()) {
574    // struct {
575    //   int *p;
576    // } __attribute__((packed, aligned(pointer_size)))
577    //
578    // which is
579    //
580    // <{ [1 x i32] }> in LLVM
581    //
582    if (RSObjectLLVMType == NULL) {
583      std::vector<const llvm::Type *> Elements;
584      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
585      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
586    }
587    return RSObjectLLVMType;
588  }
589
590  switch (mType) {
591    case DataTypeFloat32: {
592      return llvm::Type::getFloatTy(C);
593      break;
594    }
595    case DataTypeFloat64: {
596      return llvm::Type::getDoubleTy(C);
597      break;
598    }
599    case DataTypeBoolean: {
600      return llvm::Type::getInt1Ty(C);
601      break;
602    }
603    case DataTypeSigned8:
604    case DataTypeUnsigned8: {
605      return llvm::Type::getInt8Ty(C);
606      break;
607    }
608    case DataTypeSigned16:
609    case DataTypeUnsigned16:
610    case DataTypeUnsigned565:
611    case DataTypeUnsigned5551:
612    case DataTypeUnsigned4444: {
613      return llvm::Type::getInt16Ty(C);
614      break;
615    }
616    case DataTypeSigned32:
617    case DataTypeUnsigned32: {
618      return llvm::Type::getInt32Ty(C);
619      break;
620    }
621    case DataTypeSigned64: {
622    // case DataTypeUnsigned64:
623      return llvm::Type::getInt64Ty(C);
624      break;
625    }
626    default: {
627      assert(false && "Unknown data type");
628    }
629  }
630
631  return NULL;
632}
633
634/**************************** RSExportPointerType ****************************/
635
636const clang::Type *RSExportPointerType::IntegerType = NULL;
637
638RSExportPointerType
639*RSExportPointerType::Create(RSContext *Context,
640                             const clang::PointerType *PT,
641                             const llvm::StringRef &TypeName) {
642  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
643  const RSExportType *PointeeET;
644
645  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
646    PointeeET = RSExportType::Create(Context, PointeeType);
647  } else {
648    // Double or higher dimension of pointer, export as int*
649    assert(IntegerType != NULL && "Built-in integer type is not set");
650    PointeeET = RSExportPrimitiveType::Create(Context, IntegerType);
651  }
652
653  if (PointeeET == NULL) {
654    fprintf(stderr, "Failed to create type for pointee");
655    return NULL;
656  }
657
658  return new RSExportPointerType(Context, TypeName, PointeeET);
659}
660
661const llvm::Type *RSExportPointerType::convertToLLVMType() const {
662  const llvm::Type *PointeeType = mPointeeType->getLLVMType();
663  return llvm::PointerType::getUnqual(PointeeType);
664}
665
666/***************************** RSExportVectorType *****************************/
667const char* RSExportVectorType::VectorTypeNameStore[][3] = {
668  /* 0 */ { "char2",      "char3",    "char4" },
669  /* 1 */ { "uchar2",     "uchar3",   "uchar4" },
670  /* 2 */ { "short2",     "short3",   "short4" },
671  /* 3 */ { "ushort2",    "ushort3",  "ushort4" },
672  /* 4 */ { "int2",       "int3",     "int4" },
673  /* 5 */ { "uint2",      "uint3",    "uint4" },
674  /* 6 */ { "float2",     "float3",   "float4" },
675  /* 7 */ { "double2",    "double3",  "double4" },
676  /* 8 */ { "long2",      "long3",    "long4" },
677};
678
679llvm::StringRef
680RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
681  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
682
683  if ((ElementType->getTypeClass() != clang::Type::Builtin))
684    return llvm::StringRef();
685
686  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType,
687                                                  ElementType);
688  const char **BaseElement = NULL;
689
690  switch (BT->getKind()) {
691    // Compiler is smart enough to optimize following *big if branches* since
692    // they all become "constant comparison" after macro expansion
693#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
694    case builtin_type: {                                                \
695      if (type == RSExportPrimitiveType::DataTypeSigned8) \
696        BaseElement = VectorTypeNameStore[0];                           \
697      else if (type == RSExportPrimitiveType::DataTypeUnsigned8) \
698        BaseElement = VectorTypeNameStore[1];                           \
699      else if (type == RSExportPrimitiveType::DataTypeSigned16) \
700        BaseElement = VectorTypeNameStore[2];                           \
701      else if (type == RSExportPrimitiveType::DataTypeUnsigned16) \
702        BaseElement = VectorTypeNameStore[3];                           \
703      else if (type == RSExportPrimitiveType::DataTypeSigned32) \
704        BaseElement = VectorTypeNameStore[4];                           \
705      else if (type == RSExportPrimitiveType::DataTypeUnsigned32) \
706        BaseElement = VectorTypeNameStore[5];                           \
707      else if (type == RSExportPrimitiveType::DataTypeFloat32) \
708        BaseElement = VectorTypeNameStore[6];                           \
709      else if (type == RSExportPrimitiveType::DataTypeFloat64) \
710        BaseElement = VectorTypeNameStore[7];                           \
711      else if (type == RSExportPrimitiveType::DataTypeSigned64) \
712        BaseElement = VectorTypeNameStore[8];                           \
713      else if (type == RSExportPrimitiveType::DataTypeBoolean) \
714        BaseElement = VectorTypeNameStore[0];                          \
715      break;  \
716    }
717#include "slang_rs_export_type_support.inc"
718    default: {
719      return llvm::StringRef();
720    }
721  }
722
723  if ((BaseElement != NULL) &&
724      (EVT->getNumElements() > 1) &&
725      (EVT->getNumElements() <= 4))
726    return BaseElement[EVT->getNumElements() - 2];
727  else
728    return llvm::StringRef();
729}
730
731RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
732                                               const clang::ExtVectorType *EVT,
733                                               const llvm::StringRef &TypeName,
734                                               DataKind DK,
735                                               bool Normalized) {
736  assert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
737
738  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
739  RSExportPrimitiveType::DataType DT =
740      RSExportPrimitiveType::GetDataType(ElementType);
741
742  if (DT != RSExportPrimitiveType::DataTypeUnknown)
743    return new RSExportVectorType(Context,
744                                  TypeName,
745                                  DT,
746                                  DK,
747                                  Normalized,
748                                  EVT->getNumElements());
749  else
750    fprintf(stderr, "RSExportVectorType::Create : unsupported base element "
751                    "type\n");
752  return NULL;
753}
754
755const llvm::Type *RSExportVectorType::convertToLLVMType() const {
756  const llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
757  return llvm::VectorType::get(ElementType, getNumElement());
758}
759
760/***************************** RSExportMatrixType *****************************/
761RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
762                                               const clang::RecordType *RT,
763                                               const llvm::StringRef &TypeName,
764                                               unsigned Dim) {
765  assert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
766  assert((Dim > 1) && "Invalid dimension of matrix");
767
768  // Check whether the struct rs_matrix is in our expected form (but assume it's
769  // correct if we're not sure whether it's correct or not)
770  const clang::RecordDecl* RD = RT->getDecl();
771  RD = RD->getDefinition();
772  if (RD != NULL) {
773    // Find definition, perform further examination
774    if (RD->field_empty()) {
775      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
776                      "must have 1 field for saving values", TypeName.data());
777      return NULL;
778    }
779
780    clang::RecordDecl::field_iterator FIT = RD->field_begin();
781    const clang::FieldDecl *FD = *FIT;
782    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
783    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
784      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
785                      "first field should be an array with constant size",
786              TypeName.data());
787      return NULL;
788    }
789    const clang::ConstantArrayType *CAT =
790      static_cast<const clang::ConstantArrayType *>(FT);
791    const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
792    if ((ElementType == NULL) ||
793        (ElementType->getTypeClass() != clang::Type::Builtin) ||
794        (static_cast<const clang::BuiltinType *>(ElementType)->getKind()
795          != clang::BuiltinType::Float)) {
796      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
797                      "first field should be a float array", TypeName.data());
798      return NULL;
799    }
800
801    if (CAT->getSize() != Dim * Dim) {
802      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
803                      "first field should be an array with size %d",
804              TypeName.data(), Dim * Dim);
805      return NULL;
806    }
807
808    FIT++;
809    if (FIT != RD->field_end()) {
810      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
811                      "must have exactly 1 field", TypeName.data());
812      return NULL;
813    }
814  }
815
816  return new RSExportMatrixType(Context, TypeName, Dim);
817}
818
819const llvm::Type *RSExportMatrixType::convertToLLVMType() const {
820  // Construct LLVM type:
821  // struct {
822  //  float X[mDim * mDim];
823  // }
824
825  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
826  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
827                                            mDim * mDim);
828  return llvm::StructType::get(C, X, NULL);
829}
830
831/************************* RSExportConstantArrayType *************************/
832RSExportConstantArrayType
833*RSExportConstantArrayType::Create(RSContext *Context,
834                                   const clang::ConstantArrayType *CAT) {
835  assert(CAT != NULL && CAT->getTypeClass() == clang::Type::ConstantArray);
836
837  assert((CAT->getSize().getActiveBits() < 32) && "array too large");
838
839  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
840  assert((Size > 0) && "Constant array should have size greater than 0");
841
842  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
843  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
844
845  if (ElementET == NULL) {
846    fprintf(stderr, "RSExportConstantArrayType::Create : failed to create "
847                    "RSExportType for array element.\n");
848    return NULL;
849  }
850
851  return new RSExportConstantArrayType(Context,
852                                       ElementET,
853                                       Size);
854}
855
856const llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
857  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
858}
859
860/**************************** RSExportRecordType ****************************/
861RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
862                                               const clang::RecordType *RT,
863                                               const llvm::StringRef &TypeName,
864                                               bool mIsArtificial) {
865  assert(RT != NULL && RT->getTypeClass() == clang::Type::Record);
866
867  const clang::RecordDecl *RD = RT->getDecl();
868  assert(RD->isStruct());
869
870  RD = RD->getDefinition();
871  if (RD == NULL) {
872    // TODO(zonr): warn that actual struct definition isn't declared in this
873    //             moudle.
874    fprintf(stderr, "RSExportRecordType::Create : this struct is not defined "
875                    "in this module.");
876    return NULL;
877  }
878
879  // Struct layout construct by clang. We rely on this for obtaining the
880  // alloc size of a struct and offset of every field in that struct.
881  const clang::ASTRecordLayout *RL =
882      &Context->getASTContext()->getASTRecordLayout(RD);
883  assert((RL != NULL) && "Failed to retrieve the struct layout from Clang.");
884
885  RSExportRecordType *ERT =
886      new RSExportRecordType(Context,
887                             TypeName,
888                             RD->hasAttr<clang::PackedAttr>(),
889                             mIsArtificial,
890                             (RL->getSize() >> 3));
891  unsigned int Index = 0;
892
893  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
894           FE = RD->field_end();
895       FI != FE;
896       FI++, Index++) {
897#define FAILED_CREATE_FIELD(err)    do {         \
898      if (*err)                                                          \
899        fprintf(stderr, \
900                "RSExportRecordType::Create : failed to create field (%s)\n", \
901                err);                                                   \
902      delete ERT;                                                       \
903      return NULL;                                                      \
904    } while (false)
905
906    // FIXME: All fields should be primitive type
907    assert((*FI)->getKind() == clang::Decl::Field);
908    clang::FieldDecl *FD = *FI;
909
910    // We don't support bit field
911    //
912    // TODO(zonr): allow bitfield with size 8, 16, 32
913    if (FD->isBitField())
914      FAILED_CREATE_FIELD("bit field is not supported");
915
916    // Type
917    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
918
919    if (ET != NULL)
920      ERT->mFields.push_back(
921          new Field(ET, FD->getName(), ERT,
922                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
923    else
924      FAILED_CREATE_FIELD(FD->getName().str().c_str());
925#undef FAILED_CREATE_FIELD
926  }
927
928  return ERT;
929}
930
931const llvm::Type *RSExportRecordType::convertToLLVMType() const {
932  std::vector<const llvm::Type*> FieldTypes;
933
934  for (const_field_iterator FI = fields_begin(),
935           FE = fields_end();
936       FI != FE;
937       FI++) {
938    const Field *F = *FI;
939    const RSExportType *FET = F->getType();
940
941    FieldTypes.push_back(FET->getLLVMType());
942  }
943
944  return llvm::StructType::get(getRSContext()->getLLVMContext(),
945                               FieldTypes,
946                               mIsPacked);
947}
948