slang_rs_export_type.cpp revision 0da0a7dc51c25943fe31d0bfccbdfee326a3199c
1#include "slang_rs_export_type.h"
2
3#include <vector>
4
5#include "llvm/Type.h"
6#include "llvm/DerivedTypes.h"
7
8#include "llvm/ADT/StringExtras.h"
9#include "llvm/Target/TargetData.h"
10
11#include "clang/AST/RecordLayout.h"
12
13#include "slang_rs_context.h"
14#include "slang_rs_export_element.h"
15
16using namespace slang;
17
18/****************************** RSExportType ******************************/
19bool RSExportType::NormalizeType(const clang::Type *&T,
20                                 llvm::StringRef &TypeName) {
21  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
22      llvm::SmallPtrSet<const clang::Type*, 8>();
23
24  if ((T = RSExportType::TypeExportable(T, SPS)) == NULL)
25    // TODO(zonr): warn that type not exportable.
26    return false;
27
28  // Get type name
29  TypeName = RSExportType::GetTypeName(T);
30  if (TypeName.empty())
31    // TODO(zonr): warning that the type is unnamed.
32    return false;
33
34  return true;
35}
36
37const clang::Type
38*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
39  if (DD) {
40    clang::QualType T;
41    if (DD->getTypeSourceInfo())
42      T = DD->getTypeSourceInfo()->getType();
43    else
44      T = DD->getType();
45
46    if (T.isNull())
47      return NULL;
48    else
49      return T.getTypePtr();
50  }
51  return NULL;
52}
53
54llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
55  T = GET_CANONICAL_TYPE(T);
56  if (T == NULL)
57    return llvm::StringRef();
58
59  switch (T->getTypeClass()) {
60    case clang::Type::Builtin: {
61      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
62
63      switch (BT->getKind()) {
64        // Compiler is smart enough to optimize following *big if branches*
65        // since they all become "constant comparison" after macro expansion
66#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
67        case builtin_type: {                                    \
68          if (type == RSExportPrimitiveType::DataTypeFloat32)           \
69            return "float";                                             \
70          else if (type == RSExportPrimitiveType::DataTypeFloat64)      \
71            return "double";                                            \
72          else if (type == RSExportPrimitiveType::DataTypeUnsigned8)    \
73            return "uchar";                                             \
74          else if (type == RSExportPrimitiveType::DataTypeUnsigned16)   \
75            return "ushort";                                            \
76          else if (type == RSExportPrimitiveType::DataTypeUnsigned32)   \
77            return "uint";                                              \
78          else if (type == RSExportPrimitiveType::DataTypeSigned8)      \
79            return "char";                                              \
80          else if (type == RSExportPrimitiveType::DataTypeSigned16)     \
81            return "short";                                             \
82          else if (type == RSExportPrimitiveType::DataTypeSigned32)     \
83            return "int";                                               \
84          else if (type == RSExportPrimitiveType::DataTypeSigned64)     \
85            return "long";                                              \
86          else if (type == RSExportPrimitiveType::DataTypeBoolean)      \
87            return "bool";                                              \
88          else                                                          \
89            assert(false && "Unknow data type of supported builtin");   \
90          break;                                                        \
91        }
92#include "slang_rs_export_type_support.inc"
93
94          default: {
95            assert(false && "Unknown data type of the builtin");
96            break;
97          }
98        }
99      break;
100    }
101    case clang::Type::Record: {
102      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
103      llvm::StringRef Name = RD->getName();
104      if (Name.empty()) {
105          if (RD->getTypedefForAnonDecl() != NULL)
106            Name = RD->getTypedefForAnonDecl()->getName();
107
108          if (Name.empty())
109            // Try to find a name from redeclaration (i.e. typedef)
110            for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
111                     RE = RD->redecls_end();
112                 RI != RE;
113                 RI++) {
114              assert(*RI != NULL && "cannot be NULL object");
115
116              Name = (*RI)->getName();
117              if (!Name.empty())
118                break;
119            }
120      }
121      return Name;
122    }
123    case clang::Type::Pointer: {
124      // "*" plus pointee name
125      const clang::Type *PT = GET_POINTEE_TYPE(T);
126      llvm::StringRef PointeeName;
127      if (NormalizeType(PT, PointeeName)) {
128        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
129        Name[0] = '*';
130        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
131        Name[PointeeName.size() + 1] = '\0';
132        return Name;
133      }
134      break;
135    }
136    case clang::Type::ConstantArray: {
137      const clang::ConstantArrayType *ECT =
138          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);
139      return RSExportConstantArrayType::GetTypeName(ECT);
140      break;
141    }
142    case clang::Type::ExtVector: {
143      const clang::ExtVectorType *EVT =
144          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
145      return RSExportVectorType::GetTypeName(EVT);
146      break;
147    }
148    default: {
149      break;
150    }
151  }
152
153  return llvm::StringRef();
154}
155
156const clang::Type *RSExportType::TypeExportable(
157    const clang::Type *T,
158    llvm::SmallPtrSet<const clang::Type*, 8>& SPS) {
159  // Normalize first
160  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
161    return NULL;
162
163  if (SPS.count(T))
164    return T;
165
166  switch (T->getTypeClass()) {
167    case clang::Type::Builtin: {
168      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
169
170      switch (BT->getKind()) {
171#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
172        case builtin_type:
173#include "slang_rs_export_type_support.inc"
174        {
175          return T;
176        }
177        default: {
178          return NULL;
179        }
180      }
181      // Never be here
182    }
183    case clang::Type::Record: {
184      if (RSExportPrimitiveType::GetRSObjectType(T) !=
185          RSExportPrimitiveType::DataTypeUnknown)
186        return T;  // RS object type, no further checks are needed
187
188      // Check internal struct
189      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
190      if (RD != NULL)
191        RD = RD->getDefinition();
192
193      // Fast check
194      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
195        return NULL;
196
197      // Insert myself into checking set
198      SPS.insert(T);
199
200      // Check all element
201      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
202               FE = RD->field_end();
203           FI != FE;
204           FI++) {
205        const clang::FieldDecl *FD = *FI;
206        const clang::Type *FT = GetTypeOfDecl(FD);
207        FT = GET_CANONICAL_TYPE(FT);
208
209        if (!TypeExportable(FT, SPS)) {
210          fprintf(stderr, "Field `%s' in Record `%s' contains unsupported "
211                          "type\n", FD->getNameAsString().c_str(),
212                                    RD->getNameAsString().c_str());
213          FT->dump();
214          return NULL;
215        }
216      }
217
218      return T;
219    }
220    case clang::Type::Pointer: {
221      const clang::PointerType *PT = UNSAFE_CAST_TYPE(clang::PointerType, T);
222      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
223
224      if ((PointeeType->getTypeClass() != clang::Type::Pointer) &&
225         (TypeExportable(PointeeType, SPS) == NULL) )
226        return NULL;
227      else
228        return T;
229    }
230
231    case clang::Type::ConstantArray: {
232      const clang::ConstantArrayType *ECT =
233          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);
234      // No longer only support 2x2, 3x3 and 4x4 arrays
235      // if (ECT->getNumElements() != 4 &&
236      //     ECT->getNumElements() != 9 &&
237      //     ECT->getNumElements() != 16)
238      //  return NULL;
239
240      // Check base element type
241      const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(ECT);
242
243      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
244          (TypeExportable(ElementType, SPS) == NULL))
245        return NULL;
246      else
247        return T;
248    }
249    case clang::Type::ExtVector: {
250      const clang::ExtVectorType *EVT =
251          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
252      // Only vector with size 2, 3 and 4 are supported.
253      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
254        return NULL;
255
256      // Check base element type
257      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
258
259      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
260          (TypeExportable(ElementType, SPS) == NULL))
261        return NULL;
262      else
263        return T;
264    }
265    default: {
266      return NULL;
267    }
268  }
269}
270
271RSExportType *RSExportType::Create(RSContext *Context,
272                                   const clang::Type *T,
273                                   const llvm::StringRef &TypeName) {
274  // Lookup the context to see whether the type was processed before.
275  // Newly created RSExportType will insert into context
276  // in RSExportType::RSExportType()
277  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
278
279  if (ETI != Context->export_types_end())
280    return ETI->second;
281
282  RSExportType *ET = NULL;
283  switch (T->getTypeClass()) {
284    case clang::Type::Record: {
285      RSExportPrimitiveType::DataType dt =
286          RSExportPrimitiveType::GetRSObjectType(TypeName);
287      switch (dt) {
288        case RSExportPrimitiveType::DataTypeUnknown: {
289          // User-defined types
290          ET = RSExportRecordType::Create(Context,
291                                          T->getAsStructureType(),
292                                          TypeName);
293          break;
294        }
295        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
296          // 2 x 2 Matrix type
297          ET = RSExportMatrixType::Create(Context,
298                                          T->getAsStructureType(),
299                                          TypeName,
300                                          2);
301          break;
302        }
303        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
304          // 3 x 3 Matrix type
305          ET = RSExportMatrixType::Create(Context,
306                                          T->getAsStructureType(),
307                                          TypeName,
308                                          3);
309          break;
310        }
311        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
312          // 4 x 4 Matrix type
313          ET = RSExportMatrixType::Create(Context,
314                                          T->getAsStructureType(),
315                                          TypeName,
316                                          4);
317          break;
318        }
319        default: {
320          // Others are primitive types
321          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
322          break;
323        }
324      }
325      break;
326    }
327    case clang::Type::Builtin: {
328      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
329      break;
330    }
331    case clang::Type::Pointer: {
332      ET = RSExportPointerType::Create(Context,
333                                       UNSAFE_CAST_TYPE(clang::PointerType, T),
334                                       TypeName);
335      // FIXME: free the name (allocated in RSExportType::GetTypeName)
336      delete [] TypeName.data();
337      break;
338    }
339    case clang::Type::ConstantArray: {
340      ET = RSExportConstantArrayType::Create(
341          Context,
342          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T),
343          TypeName);
344      break;
345    }
346    case clang::Type::ExtVector: {
347      ET = RSExportVectorType::Create(Context,
348                                      UNSAFE_CAST_TYPE(clang::ExtVectorType, T),
349                                      TypeName);
350      break;
351    }
352    default: {
353      // TODO(zonr): warn that type is not exportable.
354      fprintf(stderr,
355              "RSExportType::Create : type '%s' is not exportable\n",
356              T->getTypeClassName());
357      break;
358    }
359  }
360
361  return ET;
362}
363
364RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
365  llvm::StringRef TypeName;
366  if (NormalizeType(T, TypeName))
367    return Create(Context, T, TypeName);
368  else
369    return NULL;
370}
371
372RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
373                                           const clang::VarDecl *VD) {
374  return RSExportType::Create(Context, GetTypeOfDecl(VD));
375}
376
377size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
378  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
379      ET->getLLVMType());
380}
381
382size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
383  if (ET->getClass() == RSExportType::ExportClassRecord)
384    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
385  else
386    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
387        ET->getLLVMType());
388}
389
390RSExportType::RSExportType(RSContext *Context, const llvm::StringRef &Name)
391    : mContext(Context),
392      // Make a copy on Name since memory stored @Name is either allocated in
393      // ASTContext or allocated in GetTypeName which will be destroyed later.
394      mName(Name.data(), Name.size()),
395      mLLVMType(NULL) {
396  // Don't cache the type whose name start with '<'. Those type failed to
397  // get their name since constructing their name in GetTypeName() requiring
398  // complicated work.
399  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
400    // TODO(zonr): Need to check whether the insertion is successful or not.
401    Context->insertExportType(llvm::StringRef(Name), this);
402  return;
403}
404
405/************************** RSExportPrimitiveType **************************/
406RSExportPrimitiveType::RSObjectTypeMapTy
407*RSExportPrimitiveType::RSObjectTypeMap = NULL;
408
409llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
410
411bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
412  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
413    return true;
414  else
415    return false;
416}
417
418RSExportPrimitiveType::DataType
419RSExportPrimitiveType::GetRSObjectType(const llvm::StringRef &TypeName) {
420  if (TypeName.empty())
421    return DataTypeUnknown;
422
423  if (RSObjectTypeMap == NULL) {
424    RSObjectTypeMap = new RSObjectTypeMapTy(16);
425
426#define USE_ELEMENT_DATA_TYPE
427#define DEF_RS_OBJECT_TYPE(type, name)                                  \
428    RSObjectTypeMap->GetOrCreateValue(name, GET_ELEMENT_DATA_TYPE(type));
429#include "slang_rs_export_element_support.inc"
430  }
431
432  RSObjectTypeMapTy::const_iterator I = RSObjectTypeMap->find(TypeName);
433  if (I == RSObjectTypeMap->end())
434    return DataTypeUnknown;
435  else
436    return I->getValue();
437}
438
439RSExportPrimitiveType::DataType
440RSExportPrimitiveType::GetRSObjectType(const clang::Type *T) {
441  T = GET_CANONICAL_TYPE(T);
442  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
443    return DataTypeUnknown;
444
445  return GetRSObjectType( RSExportType::GetTypeName(T) );
446}
447
448const size_t
449RSExportPrimitiveType::SizeOfDataTypeInBits[
450    RSExportPrimitiveType::DataTypeMax + 1] = {
451  16,  // DataTypeFloat16
452  32,  // DataTypeFloat32
453  64,  // DataTypeFloat64
454  8,   // DataTypeSigned8
455  16,  // DataTypeSigned16
456  32,  // DataTypeSigned32
457  64,  // DataTypeSigned64
458  8,   // DataTypeUnsigned8
459  16,  // DataTypeUnsigned16
460  32,  // DataTypeUnsigned32
461  64,  // DataTypeUnSigned64
462  1,   // DataTypeBoolean
463
464  16,  // DataTypeUnsigned565
465  16,  // DataTypeUnsigned5551
466  16,  // DataTypeUnsigned4444
467
468  128,  // DataTypeRSMatrix2x2
469  288,  // DataTypeRSMatrix3x3
470  512,  // DataTypeRSMatrix4x4
471
472  32,  // DataTypeRSElement
473  32,  // DataTypeRSType
474  32,  // DataTypeRSAllocation
475  32,  // DataTypeRSSampler
476  32,  // DataTypeRSScript
477  32,  // DataTypeRSMesh
478  32,  // DataTypeRSProgramFragment
479  32,  // DataTypeRSProgramVertex
480  32,  // DataTypeRSProgramRaster
481  32,  // DataTypeRSProgramStore
482  32,  // DataTypeRSFont
483  0
484};
485
486size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
487  assert(((EPT->getType() >= DataTypeFloat32) &&
488          (EPT->getType() < DataTypeMax)) &&
489         "RSExportPrimitiveType::GetSizeInBits : unknown data type");
490  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
491}
492
493RSExportPrimitiveType::DataType
494RSExportPrimitiveType::GetDataType(const clang::Type *T) {
495  if (T == NULL)
496    return DataTypeUnknown;
497
498  switch (T->getTypeClass()) {
499    case clang::Type::Builtin: {
500      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
501      switch (BT->getKind()) {
502#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
503        case builtin_type: {                                    \
504          return type;                                          \
505          break;                                                \
506        }
507#include "slang_rs_export_type_support.inc"
508
509        // The size of types Long, ULong and WChar depend on platform so we
510        // abandon the support to them. Type of its size exceeds 32 bits (e.g.
511        // int64_t, double, etc.): no support
512
513        default: {
514          // TODO(zonr): warn that the type is unsupported
515          fprintf(stderr, "RSExportPrimitiveType::GetDataType : built-in type "
516                          "has no corresponding data type for built-in type");
517          break;
518        }
519      }
520      break;
521    }
522
523    case clang::Type::Record: {
524      // must be RS object type
525      return RSExportPrimitiveType::GetRSObjectType(T);
526      break;
527    }
528
529    default: {
530      fprintf(stderr, "RSExportPrimitiveType::GetDataType : type '%s' is not "
531                      "supported primitive type", T->getTypeClassName());
532      break;
533    }
534  }
535
536  return DataTypeUnknown;
537}
538
539RSExportPrimitiveType
540*RSExportPrimitiveType::Create(RSContext *Context,
541                               const clang::Type *T,
542                               const llvm::StringRef &TypeName,
543                               DataKind DK,
544                               bool Normalized) {
545  DataType DT = GetDataType(T);
546
547  if ((DT == DataTypeUnknown) || TypeName.empty())
548    return NULL;
549  else
550    return new RSExportPrimitiveType(Context, TypeName, DT, DK, Normalized);
551}
552
553RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
554                                                     const clang::Type *T,
555                                                     DataKind DK) {
556  llvm::StringRef TypeName;
557  if (RSExportType::NormalizeType(T, TypeName) && IsPrimitiveType(T))
558    return Create(Context, T, TypeName, DK);
559  else
560    return NULL;
561}
562
563RSExportType::ExportClass RSExportPrimitiveType::getClass() const {
564  return RSExportType::ExportClassPrimitive;
565}
566
567const llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
568  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
569
570  if (isRSObjectType()) {
571    // struct {
572    //   int *p;
573    // } __attribute__((packed, aligned(pointer_size)))
574    //
575    // which is
576    //
577    // <{ [1 x i32] }> in LLVM
578    //
579    if (RSObjectLLVMType == NULL) {
580      std::vector<const llvm::Type *> Elements;
581      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
582      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
583    }
584    return RSObjectLLVMType;
585  }
586
587  switch (mType) {
588    case DataTypeFloat32: {
589      return llvm::Type::getFloatTy(C);
590      break;
591    }
592    case DataTypeFloat64: {
593      return llvm::Type::getDoubleTy(C);
594      break;
595    }
596    case DataTypeBoolean: {
597      return llvm::Type::getInt1Ty(C);
598      break;
599    }
600    case DataTypeSigned8:
601    case DataTypeUnsigned8: {
602      return llvm::Type::getInt8Ty(C);
603      break;
604    }
605    case DataTypeSigned16:
606    case DataTypeUnsigned16:
607    case DataTypeUnsigned565:
608    case DataTypeUnsigned5551:
609    case DataTypeUnsigned4444: {
610      return llvm::Type::getInt16Ty(C);
611      break;
612    }
613    case DataTypeSigned32:
614    case DataTypeUnsigned32: {
615      return llvm::Type::getInt32Ty(C);
616      break;
617    }
618    case DataTypeSigned64: {
619    // case DataTypeUnsigned64:
620      return llvm::Type::getInt64Ty(C);
621      break;
622    }
623    default: {
624      assert(false && "Unknown data type");
625    }
626  }
627
628  return NULL;
629}
630
631/**************************** RSExportPointerType ****************************/
632
633const clang::Type *RSExportPointerType::IntegerType = NULL;
634
635RSExportPointerType
636*RSExportPointerType::Create(RSContext *Context,
637                             const clang::PointerType *PT,
638                             const llvm::StringRef &TypeName) {
639  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
640  const RSExportType *PointeeET;
641
642  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
643    PointeeET = RSExportType::Create(Context, PointeeType);
644  } else {
645    // Double or higher dimension of pointer, export as int*
646    assert(IntegerType != NULL && "Built-in integer type is not set");
647    PointeeET = RSExportPrimitiveType::Create(Context, IntegerType);
648  }
649
650  if (PointeeET == NULL) {
651    fprintf(stderr, "Failed to create type for pointee");
652    return NULL;
653  }
654
655  return new RSExportPointerType(Context, TypeName, PointeeET);
656}
657
658RSExportType::ExportClass RSExportPointerType::getClass() const {
659  return RSExportType::ExportClassPointer;
660}
661
662const llvm::Type *RSExportPointerType::convertToLLVMType() const {
663  const llvm::Type *PointeeType = mPointeeType->getLLVMType();
664  return llvm::PointerType::getUnqual(PointeeType);
665}
666
667/************************* RSExportConstantArrayType *************************/
668llvm::StringRef
669RSExportConstantArrayType::GetTypeName(const clang::ConstantArrayType *CT) {
670  llvm::APInt i = CT->getSize();
671  if (i == 4) {
672    return llvm::StringRef("rs_matrix2x2");
673  } else if (i == 9) {
674    return llvm::StringRef("rs_matrix3x3");
675  } else if (i == 16) {
676    return llvm::StringRef("rs_matrix4x4");
677  }
678  return llvm::StringRef();
679}
680
681RSExportConstantArrayType
682*RSExportConstantArrayType::Create(RSContext *Context,
683                                   const clang::ConstantArrayType *CT,
684                                   const llvm::StringRef &TypeName,
685                                   DataKind DK,
686                                   bool Normalized) {
687  assert(CT != NULL && CT->getTypeClass() == clang::Type::ConstantArray);
688
689  int64_t Size = CT->getSize().getSExtValue();
690  RSExportPrimitiveType::DataType DT;
691  if (Size == 4) {
692    DT = RSExportPrimitiveType::DataTypeRSMatrix2x2;
693  } else if (Size == 9) {
694    DT = RSExportPrimitiveType::DataTypeRSMatrix3x3;
695  } else if (Size == 16) {
696    DT = RSExportPrimitiveType::DataTypeRSMatrix4x4;
697  } else {
698    fprintf(stderr, "RSExportConstantArrayType::Create : unsupported base "
699                    "element type\n");
700    return NULL;
701  }
702
703  return new RSExportConstantArrayType(Context,
704                                       TypeName,
705                                       DT,
706                                       DK,
707                                       Normalized,
708                                       Size);
709}
710
711RSExportType::ExportClass RSExportConstantArrayType::getClass() const {
712  return RSExportType::ExportClassConstantArray;
713}
714
715const llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
716  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
717  const llvm::Type *typ;
718  switch (getType()) {
719    case DataTypeFloat32:
720    case DataTypeRSMatrix2x2:
721    case DataTypeRSMatrix3x3:
722    case DataTypeRSMatrix4x4: {
723      typ = llvm::Type::getFloatTy(C);
724      break;
725    }
726    case DataTypeFloat64: {
727      typ = llvm::Type::getDoubleTy(C);
728      break;
729    }
730    case DataTypeBoolean: {
731      typ = llvm::Type::getInt1Ty(C);
732      break;
733    }
734    case DataTypeSigned8:
735    case DataTypeUnsigned8: {
736      typ = llvm::Type::getInt8Ty(C);
737      break;
738    }
739    case DataTypeSigned16:
740    case DataTypeUnsigned16:
741    case DataTypeUnsigned565:
742    case DataTypeUnsigned5551:
743    case DataTypeUnsigned4444: {
744      typ = llvm::Type::getInt16Ty(C);
745      break;
746    }
747    case DataTypeSigned32:
748    case DataTypeUnsigned32: {
749      typ = llvm::Type::getInt32Ty(C);
750      break;
751    }
752    case DataTypeSigned64: {
753    //case DataTypeUnsigned64:
754      typ = llvm::Type::getInt64Ty(C);
755      break;
756    }
757    default: {
758      assert(false && "Unknown data type");
759      break;
760    }
761  }
762  return llvm::ArrayType::get(typ, mNumElement);
763}
764
765/***************************** RSExportVectorType *****************************/
766const char* RSExportVectorType::VectorTypeNameStore[][3] = {
767  /* 0 */ { "char2",      "char3",    "char4" },
768  /* 1 */ { "uchar2",     "uchar3",   "uchar4" },
769  /* 2 */ { "short2",     "short3",   "short4" },
770  /* 3 */ { "ushort2",    "ushort3",  "ushort4" },
771  /* 4 */ { "int2",       "int3",     "int4" },
772  /* 5 */ { "uint2",      "uint3",    "uint4" },
773  /* 6 */ { "float2",     "float3",   "float4" },
774  /* 7 */ { "double2",    "double3",  "double4" },
775  /* 8 */ { "long2",      "long3",    "long4" },
776};
777
778llvm::StringRef
779RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
780  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
781
782  if ((ElementType->getTypeClass() != clang::Type::Builtin))
783    return llvm::StringRef();
784
785  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType,
786                                                  ElementType);
787  const char **BaseElement = NULL;
788
789  switch (BT->getKind()) {
790    // Compiler is smart enough to optimize following *big if branches* since
791    // they all become "constant comparison" after macro expansion
792#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
793    case builtin_type: {                                                \
794      if (type == RSExportPrimitiveType::DataTypeSigned8) \
795        BaseElement = VectorTypeNameStore[0];                           \
796      else if (type == RSExportPrimitiveType::DataTypeUnsigned8) \
797        BaseElement = VectorTypeNameStore[1];                           \
798      else if (type == RSExportPrimitiveType::DataTypeSigned16) \
799        BaseElement = VectorTypeNameStore[2];                           \
800      else if (type == RSExportPrimitiveType::DataTypeUnsigned16) \
801        BaseElement = VectorTypeNameStore[3];                           \
802      else if (type == RSExportPrimitiveType::DataTypeSigned32) \
803        BaseElement = VectorTypeNameStore[4];                           \
804      else if (type == RSExportPrimitiveType::DataTypeUnsigned32) \
805        BaseElement = VectorTypeNameStore[5];                           \
806      else if (type == RSExportPrimitiveType::DataTypeFloat32) \
807        BaseElement = VectorTypeNameStore[6];                           \
808      else if (type == RSExportPrimitiveType::DataTypeFloat64) \
809        BaseElement = VectorTypeNameStore[7];                           \
810      else if (type == RSExportPrimitiveType::DataTypeSigned64) \
811        BaseElement = VectorTypeNameStore[8];                           \
812      else if (type == RSExportPrimitiveType::DataTypeBoolean) \
813        BaseElement = VectorTypeNameStore[0];                          \
814      break;  \
815    }
816#include "slang_rs_export_type_support.inc"
817    default: {
818      return llvm::StringRef();
819    }
820  }
821
822  if ((BaseElement != NULL) &&
823      (EVT->getNumElements() > 1) &&
824      (EVT->getNumElements() <= 4))
825    return BaseElement[EVT->getNumElements() - 2];
826  else
827    return llvm::StringRef();
828}
829
830RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
831                                               const clang::ExtVectorType *EVT,
832                                               const llvm::StringRef &TypeName,
833                                               DataKind DK,
834                                               bool Normalized) {
835  assert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
836
837  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
838  RSExportPrimitiveType::DataType DT =
839      RSExportPrimitiveType::GetDataType(ElementType);
840
841  if (DT != RSExportPrimitiveType::DataTypeUnknown)
842    return new RSExportVectorType(Context,
843                                  TypeName,
844                                  DT,
845                                  DK,
846                                  Normalized,
847                                  EVT->getNumElements());
848  else
849    fprintf(stderr, "RSExportVectorType::Create : unsupported base element "
850                    "type\n");
851  return NULL;
852}
853
854RSExportType::ExportClass RSExportVectorType::getClass() const {
855  return RSExportType::ExportClassVector;
856}
857
858const llvm::Type *RSExportVectorType::convertToLLVMType() const {
859  const llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
860  return llvm::VectorType::get(ElementType, getNumElement());
861}
862
863/***************************** RSExportMatrixType *****************************/
864RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
865                                               const clang::RecordType *RT,
866                                               const llvm::StringRef &TypeName,
867                                               unsigned Dim) {
868  assert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
869  assert((Dim > 1) && "Invalid dimension of matrix");
870
871  // Check whether the struct rs_matrix is in our expected form (but assume it's
872  // correct if we're not sure whether it's correct or not)
873  const clang::RecordDecl* RD = RT->getDecl();
874  RD = RD->getDefinition();
875  if (RD != NULL) {
876    // Find definition, perform further examination
877    if (RD->field_empty()) {
878      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
879                      "must have 1 field for saving values", TypeName.data());
880      return NULL;
881    }
882
883    clang::RecordDecl::field_iterator FIT = RD->field_begin();
884    const clang::FieldDecl *FD = *FIT;
885    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
886    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
887      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
888                      "first field should be an array with constant size",
889              TypeName.data());
890      return NULL;
891    }
892    const clang::ConstantArrayType *CAT =
893      static_cast<const clang::ConstantArrayType *>(FT);
894    const clang::Type *ElementType =
895      GET_CANONICAL_TYPE(CAT->getElementType().getTypePtr());
896    if ((ElementType == NULL) ||
897        (ElementType->getTypeClass() != clang::Type::Builtin) ||
898        (static_cast<const clang::BuiltinType *>(ElementType)->getKind()
899          != clang::BuiltinType::Float)) {
900      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
901                      "first field should be a float array", TypeName.data());
902      return NULL;
903    }
904
905    if (CAT->getSize() != Dim * Dim) {
906      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
907                      "first field should be an array with size %d",
908              TypeName.data(), Dim * Dim);
909      return NULL;
910    }
911
912    FIT++;
913    if (FIT != RD->field_end()) {
914      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
915                      "must have exactly 1 field", TypeName.data());
916      return NULL;
917    }
918  }
919
920  return new RSExportMatrixType(Context, TypeName, Dim);
921}
922
923RSExportType::ExportClass RSExportMatrixType::getClass() const {
924  return RSExportType::ExportClassMatrix;
925}
926
927const llvm::Type *RSExportMatrixType::convertToLLVMType() const {
928  // Construct LLVM type:
929  // struct {
930  //  float X[mDim * mDim];
931  // }
932
933  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
934  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
935                                            mDim * mDim);
936  return llvm::StructType::get(C, X, NULL);
937}
938
939/**************************** RSExportRecordType ****************************/
940RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
941                                               const clang::RecordType *RT,
942                                               const llvm::StringRef &TypeName,
943                                               bool mIsArtificial) {
944  assert(RT != NULL && RT->getTypeClass() == clang::Type::Record);
945
946  const clang::RecordDecl *RD = RT->getDecl();
947  assert(RD->isStruct());
948
949  RD = RD->getDefinition();
950  if (RD == NULL) {
951    // TODO(zonr): warn that actual struct definition isn't declared in this
952    //             moudle.
953    fprintf(stderr, "RSExportRecordType::Create : this struct is not defined "
954                    "in this module.");
955    return NULL;
956  }
957
958  // Struct layout construct by clang. We rely on this for obtaining the
959  // alloc size of a struct and offset of every field in that struct.
960  const clang::ASTRecordLayout *RL =
961      &Context->getASTContext()->getASTRecordLayout(RD);
962  assert((RL != NULL) && "Failed to retrieve the struct layout from Clang.");
963
964  RSExportRecordType *ERT =
965      new RSExportRecordType(Context,
966                             TypeName,
967                             RD->hasAttr<clang::PackedAttr>(),
968                             mIsArtificial,
969                             (RL->getSize() >> 3));
970  unsigned int Index = 0;
971
972  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
973           FE = RD->field_end();
974       FI != FE;
975       FI++, Index++) {
976#define FAILED_CREATE_FIELD(err)    do {         \
977      if (*err)                                                          \
978        fprintf(stderr, \
979                "RSExportRecordType::Create : failed to create field (%s)\n", \
980                err);                                                   \
981      delete ERT;                                                       \
982      return NULL;                                                      \
983    } while (false)
984
985    // FIXME: All fields should be primitive type
986    assert((*FI)->getKind() == clang::Decl::Field);
987    clang::FieldDecl *FD = *FI;
988
989    // We don't support bit field
990    //
991    // TODO(zonr): allow bitfield with size 8, 16, 32
992    if (FD->isBitField())
993      FAILED_CREATE_FIELD("bit field is not supported");
994
995    // Type
996    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
997
998    if (ET != NULL)
999      ERT->mFields.push_back(
1000          new Field(ET, FD->getName(), ERT,
1001                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1002    else
1003      FAILED_CREATE_FIELD(FD->getName().str().c_str());
1004#undef FAILED_CREATE_FIELD
1005  }
1006
1007  return ERT;
1008}
1009
1010RSExportType::ExportClass RSExportRecordType::getClass() const {
1011  return RSExportType::ExportClassRecord;
1012}
1013
1014const llvm::Type *RSExportRecordType::convertToLLVMType() const {
1015  std::vector<const llvm::Type*> FieldTypes;
1016
1017  for (const_field_iterator FI = fields_begin(),
1018           FE = fields_end();
1019       FI != FE;
1020       FI++) {
1021    const Field *F = *FI;
1022    const RSExportType *FET = F->getType();
1023
1024    FieldTypes.push_back(FET->getLLVMType());
1025  }
1026
1027  return llvm::StructType::get(getRSContext()->getLLVMContext(),
1028                               FieldTypes,
1029                               mIsPacked);
1030}
1031