slang_rs_export_type.cpp revision e639eb5caa2c386b4a60659a4929e8a6141a2cbe
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <list>
20#include <vector>
21
22#include "clang/AST/RecordLayout.h"
23
24#include "llvm/ADT/StringExtras.h"
25
26#include "llvm/DerivedTypes.h"
27
28#include "llvm/Target/TargetData.h"
29
30#include "llvm/Type.h"
31
32#include "slang_rs_context.h"
33#include "slang_rs_export_element.h"
34#include "slang_rs_type_spec.h"
35
// Makes the enclosing function return false when the ParentClass part of the
// comparison with E fails. Used at the top of equals() overrides so each level
// of the class hierarchy checks its base first.
// Wrapped in do { } while (0) so the macro expands to exactly one statement:
// it is safe inside an unbraced if/else and requires a trailing semicolon.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  do {                                        \
    if (!ParentClass::equals(E))              \
      return false;                           \
  } while (0)
39
40namespace slang {
41
42/****************************** RSExportType ******************************/
43bool RSExportType::NormalizeType(const clang::Type *&T,
44                                 llvm::StringRef &TypeName) {
45  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
46      llvm::SmallPtrSet<const clang::Type*, 8>();
47
48  if ((T = RSExportType::TypeExportable(T, SPS)) == NULL)
49    // TODO(zonr): warn that type not exportable.
50    return false;
51
52  // Get type name
53  TypeName = RSExportType::GetTypeName(T);
54  if (TypeName.empty())
55    // TODO(zonr): warning that the type is unnamed.
56    return false;
57
58  return true;
59}
60
61const clang::Type
62*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
63  if (DD) {
64    clang::QualType T;
65    if (DD->getTypeSourceInfo())
66      T = DD->getTypeSourceInfo()->getType();
67    else
68      T = DD->getType();
69
70    if (T.isNull())
71      return NULL;
72    else
73      return T.getTypePtr();
74  }
75  return NULL;
76}
77
78llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
79  T = GET_CANONICAL_TYPE(T);
80  if (T == NULL)
81    return llvm::StringRef();
82
83  switch (T->getTypeClass()) {
84    case clang::Type::Builtin: {
85      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
86
87      switch (BT->getKind()) {
88#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
89        case builtin_type:                                    \
90          return cname;                                       \
91        break;
92#include "RSClangBuiltinEnums.inc"
93        default: {
94          assert(false && "Unknown data type of the builtin");
95          break;
96        }
97      }
98      break;
99    }
100    case clang::Type::Record: {
101      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
102      llvm::StringRef Name = RD->getName();
103      if (Name.empty()) {
104          if (RD->getTypedefForAnonDecl() != NULL)
105            Name = RD->getTypedefForAnonDecl()->getName();
106
107          if (Name.empty())
108            // Try to find a name from redeclaration (i.e. typedef)
109            for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
110                     RE = RD->redecls_end();
111                 RI != RE;
112                 RI++) {
113              assert(*RI != NULL && "cannot be NULL object");
114
115              Name = (*RI)->getName();
116              if (!Name.empty())
117                break;
118            }
119      }
120      return Name;
121    }
122    case clang::Type::Pointer: {
123      // "*" plus pointee name
124      const clang::Type *PT = GET_POINTEE_TYPE(T);
125      llvm::StringRef PointeeName;
126      if (NormalizeType(PT, PointeeName)) {
127        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
128        Name[0] = '*';
129        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
130        Name[PointeeName.size() + 1] = '\0';
131        return Name;
132      }
133      break;
134    }
135    case clang::Type::ExtVector: {
136      const clang::ExtVectorType *EVT =
137          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
138      return RSExportVectorType::GetTypeName(EVT);
139      break;
140    }
141    case clang::Type::ConstantArray : {
142      // Construct name for a constant array is too complicated.
143      return DUMMY_TYPE_NAME_FOR_RS_CONSTANT_ARRAY_TYPE;
144    }
145    default: {
146      break;
147    }
148  }
149
150  return llvm::StringRef();
151}
152
153const clang::Type *RSExportType::TypeExportable(
154    const clang::Type *T,
155    llvm::SmallPtrSet<const clang::Type*, 8>& SPS) {
156  // Normalize first
157  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
158    return NULL;
159
160  if (SPS.count(T))
161    return T;
162
163  switch (T->getTypeClass()) {
164    case clang::Type::Builtin: {
165      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
166
167      switch (BT->getKind()) {
168#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
169        case builtin_type:
170#include "RSClangBuiltinEnums.inc"
171          return T;
172        default: {
173          return NULL;
174        }
175      }
176    }
177    case clang::Type::Record: {
178      if (RSExportPrimitiveType::GetRSSpecificType(T) !=
179          RSExportPrimitiveType::DataTypeUnknown)
180        return T;  // RS object type, no further checks are needed
181
182      // Check internal struct
183      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
184      if (RD != NULL)
185        RD = RD->getDefinition();
186
187      // Fast check
188      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
189        return NULL;
190
191      // Insert myself into checking set
192      SPS.insert(T);
193
194      // Check all element
195      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
196               FE = RD->field_end();
197           FI != FE;
198           FI++) {
199        const clang::FieldDecl *FD = *FI;
200        const clang::Type *FT = GetTypeOfDecl(FD);
201        FT = GET_CANONICAL_TYPE(FT);
202
203        if (!TypeExportable(FT, SPS)) {
204          fprintf(stderr, "Field `%s' in Record `%s' contains unsupported "
205                          "type\n", FD->getNameAsString().c_str(),
206                                    RD->getNameAsString().c_str());
207          FT->dump();
208          return NULL;
209        }
210      }
211
212      return T;
213    }
214    case clang::Type::Pointer: {
215      const clang::PointerType *PT = UNSAFE_CAST_TYPE(clang::PointerType, T);
216      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
217
218      if (PointeeType->getTypeClass() == clang::Type::Pointer)
219        return T;
220      // We don't support pointer with array-type pointee or unsupported pointee
221      // type
222      if (PointeeType->isArrayType() ||
223         (TypeExportable(PointeeType, SPS) == NULL) )
224        return NULL;
225      else
226        return T;
227    }
228    case clang::Type::ExtVector: {
229      const clang::ExtVectorType *EVT =
230          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
231      // Only vector with size 2, 3 and 4 are supported.
232      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
233        return NULL;
234
235      // Check base element type
236      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
237
238      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
239          (TypeExportable(ElementType, SPS) == NULL))
240        return NULL;
241      else
242        return T;
243    }
244    case clang::Type::ConstantArray: {
245      const clang::ConstantArrayType *CAT =
246          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);
247
248      // Check size
249      if (CAT->getSize().getActiveBits() > 32) {
250        fprintf(stderr, "RSExportConstantArrayType::Create : array with too "
251                        "large size (> 2^32).\n");
252        return NULL;
253      }
254      // Check element type
255      const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
256      if (ElementType->isArrayType()) {
257        fprintf(stderr, "RSExportType::TypeExportable : constant array with 2 "
258                        "or higher dimension of constant is not supported.\n");
259        return NULL;
260      }
261      if (TypeExportable(ElementType, SPS) == NULL)
262        return NULL;
263      else
264        return T;
265    }
266    default: {
267      return NULL;
268    }
269  }
270}
271
272RSExportType *RSExportType::Create(RSContext *Context,
273                                   const clang::Type *T,
274                                   const llvm::StringRef &TypeName) {
275  // Lookup the context to see whether the type was processed before.
276  // Newly created RSExportType will insert into context
277  // in RSExportType::RSExportType()
278  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
279
280  if (ETI != Context->export_types_end())
281    return ETI->second;
282
283  RSExportType *ET = NULL;
284  switch (T->getTypeClass()) {
285    case clang::Type::Record: {
286      RSExportPrimitiveType::DataType dt =
287          RSExportPrimitiveType::GetRSSpecificType(TypeName);
288      switch (dt) {
289        case RSExportPrimitiveType::DataTypeUnknown: {
290          // User-defined types
291          ET = RSExportRecordType::Create(Context,
292                                          T->getAsStructureType(),
293                                          TypeName);
294          break;
295        }
296        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
297          // 2 x 2 Matrix type
298          ET = RSExportMatrixType::Create(Context,
299                                          T->getAsStructureType(),
300                                          TypeName,
301                                          2);
302          break;
303        }
304        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
305          // 3 x 3 Matrix type
306          ET = RSExportMatrixType::Create(Context,
307                                          T->getAsStructureType(),
308                                          TypeName,
309                                          3);
310          break;
311        }
312        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
313          // 4 x 4 Matrix type
314          ET = RSExportMatrixType::Create(Context,
315                                          T->getAsStructureType(),
316                                          TypeName,
317                                          4);
318          break;
319        }
320        default: {
321          // Others are primitive types
322          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
323          break;
324        }
325      }
326      break;
327    }
328    case clang::Type::Builtin: {
329      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
330      break;
331    }
332    case clang::Type::Pointer: {
333      ET = RSExportPointerType::Create(Context,
334                                       UNSAFE_CAST_TYPE(clang::PointerType, T),
335                                       TypeName);
336      // FIXME: free the name (allocated in RSExportType::GetTypeName)
337      delete [] TypeName.data();
338      break;
339    }
340    case clang::Type::ExtVector: {
341      ET = RSExportVectorType::Create(Context,
342                                      UNSAFE_CAST_TYPE(clang::ExtVectorType, T),
343                                      TypeName);
344      break;
345    }
346    case clang::Type::ConstantArray: {
347      ET = RSExportConstantArrayType::Create(
348              Context,
349              UNSAFE_CAST_TYPE(clang::ConstantArrayType, T));
350      break;
351    }
352    default: {
353      // TODO(zonr): warn that type is not exportable.
354      fprintf(stderr,
355              "RSExportType::Create : type '%s' is not exportable\n",
356              T->getTypeClassName());
357      break;
358    }
359  }
360
361  return ET;
362}
363
364RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
365  llvm::StringRef TypeName;
366  if (NormalizeType(T, TypeName))
367    return Create(Context, T, TypeName);
368  else
369    return NULL;
370}
371
372RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
373                                           const clang::VarDecl *VD) {
374  return RSExportType::Create(Context, GetTypeOfDecl(VD));
375}
376
377size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
378  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
379      ET->getLLVMType());
380}
381
382size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
383  if (ET->getClass() == RSExportType::ExportClassRecord)
384    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
385  else
386    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
387        ET->getLLVMType());
388}
389
390RSExportType::RSExportType(RSContext *Context,
391                           ExportClass Class,
392                           const llvm::StringRef &Name)
393    : RSExportable(Context, RSExportable::EX_TYPE),
394      mClass(Class),
395      // Make a copy on Name since memory stored @Name is either allocated in
396      // ASTContext or allocated in GetTypeName which will be destroyed later.
397      mName(Name.data(), Name.size()),
398      mLLVMType(NULL),
399      mSpecType(NULL) {
400  // Don't cache the type whose name start with '<'. Those type failed to
401  // get their name since constructing their name in GetTypeName() requiring
402  // complicated work.
403  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
404    // TODO(zonr): Need to check whether the insertion is successful or not.
405    Context->insertExportType(llvm::StringRef(Name), this);
406  return;
407}
408
409bool RSExportType::keep() {
410  if (!RSExportable::keep())
411    return false;
412  // Invalidate converted LLVM type.
413  mLLVMType = NULL;
414  return true;
415}
416
417bool RSExportType::equals(const RSExportable *E) const {
418  CHECK_PARENT_EQUALITY(RSExportable, E);
419  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
420}
421
422RSExportType::~RSExportType() {
423  delete mSpecType;
424}
425
426/************************** RSExportPrimitiveType **************************/
427llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
428RSExportPrimitiveType::RSSpecificTypeMap;
429
430llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
431
432bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
433  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
434    return true;
435  else
436    return false;
437}
438
439RSExportPrimitiveType::DataType
440RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
441  if (TypeName.empty())
442    return DataTypeUnknown;
443
444  if (RSSpecificTypeMap->empty()) {
445#define ENUM_RS_MATRIX_TYPE(type, cname, dim)                       \
446    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
447#include "RSMatrixTypeEnums.inc"
448#define ENUM_RS_OBJECT_TYPE(type, cname)                            \
449    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
450#include "RSObjectTypeEnums.inc"
451  }
452
453  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
454  if (I == RSSpecificTypeMap->end())
455    return DataTypeUnknown;
456  else
457    return I->getValue();
458}
459
460RSExportPrimitiveType::DataType
461RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
462  T = GET_CANONICAL_TYPE(T);
463  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
464    return DataTypeUnknown;
465
466  return GetRSSpecificType( RSExportType::GetTypeName(T) );
467}
468
469bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
470  return ((DT >= FirstRSMatrixType) && (DT <= LastRSMatrixType));
471}
472
473bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
474  return ((DT >= FirstRSObjectType) && (DT <= LastRSObjectType));
475}
476
477const size_t RSExportPrimitiveType::SizeOfDataTypeInBits[] = {
478#define ENUM_RS_DATA_TYPE(type, cname, bits)  \
479  bits,
480#include "RSDataTypeEnums.inc"
481  0   // DataTypeMax
482};
483
484size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
485  assert(((EPT->getType() > DataTypeUnknown) &&
486          (EPT->getType() < DataTypeMax)) &&
487         "RSExportPrimitiveType::GetSizeInBits : unknown data type");
488  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
489}
490
491RSExportPrimitiveType::DataType
492RSExportPrimitiveType::GetDataType(const clang::Type *T) {
493  if (T == NULL)
494    return DataTypeUnknown;
495
496  switch (T->getTypeClass()) {
497    case clang::Type::Builtin: {
498      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
499      switch (BT->getKind()) {
500#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
501        case builtin_type: {                                  \
502          return DataType ## type;                            \
503        }
504#include "RSClangBuiltinEnums.inc"
505        // The size of type WChar depend on platform so we abandon the support
506        // to them.
507        default: {
508          fprintf(stderr, "RSExportPrimitiveType::GetDataType : unsupported "
509                          "built-in type '%s'\n.", T->getTypeClassName());
510          break;
511        }
512      }
513      break;
514    }
515    case clang::Type::Record: {
516      // must be RS object type
517      return RSExportPrimitiveType::GetRSSpecificType(T);
518    }
519    default: {
520      fprintf(stderr, "RSExportPrimitiveType::GetDataType : type '%s' is not "
521                      "supported primitive type\n", T->getTypeClassName());
522      break;
523    }
524  }
525
526  return DataTypeUnknown;
527}
528
529RSExportPrimitiveType
530*RSExportPrimitiveType::Create(RSContext *Context,
531                               const clang::Type *T,
532                               const llvm::StringRef &TypeName,
533                               DataKind DK,
534                               bool Normalized) {
535  DataType DT = GetDataType(T);
536
537  if ((DT == DataTypeUnknown) || TypeName.empty())
538    return NULL;
539  else
540    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
541                                     DT, DK, Normalized);
542}
543
544RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
545                                                     const clang::Type *T,
546                                                     DataKind DK) {
547  llvm::StringRef TypeName;
548  if (RSExportType::NormalizeType(T, TypeName) && IsPrimitiveType(T))
549    return Create(Context, T, TypeName, DK);
550  else
551    return NULL;
552}
553
554const llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
555  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
556
557  if (isRSObjectType()) {
558    // struct {
559    //   int *p;
560    // } __attribute__((packed, aligned(pointer_size)))
561    //
562    // which is
563    //
564    // <{ [1 x i32] }> in LLVM
565    //
566    if (RSObjectLLVMType == NULL) {
567      std::vector<const llvm::Type *> Elements;
568      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
569      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
570    }
571    return RSObjectLLVMType;
572  }
573
574  switch (mType) {
575    case DataTypeFloat32: {
576      return llvm::Type::getFloatTy(C);
577      break;
578    }
579    case DataTypeFloat64: {
580      return llvm::Type::getDoubleTy(C);
581      break;
582    }
583    case DataTypeBoolean: {
584      return llvm::Type::getInt1Ty(C);
585      break;
586    }
587    case DataTypeSigned8:
588    case DataTypeUnsigned8: {
589      return llvm::Type::getInt8Ty(C);
590      break;
591    }
592    case DataTypeSigned16:
593    case DataTypeUnsigned16:
594    case DataTypeUnsigned565:
595    case DataTypeUnsigned5551:
596    case DataTypeUnsigned4444: {
597      return llvm::Type::getInt16Ty(C);
598      break;
599    }
600    case DataTypeSigned32:
601    case DataTypeUnsigned32: {
602      return llvm::Type::getInt32Ty(C);
603      break;
604    }
605    case DataTypeSigned64:
606    case DataTypeUnsigned64: {
607      return llvm::Type::getInt64Ty(C);
608      break;
609    }
610    default: {
611      assert(false && "Unknown data type");
612    }
613  }
614
615  return NULL;
616}
617
618union RSType *RSExportPrimitiveType::convertToSpecType() const {
619  llvm::OwningPtr<union RSType> ST(new union RSType);
620  RS_TYPE_SET_CLASS(ST, RS_TC_Primitive);
621  // enum RSExportPrimitiveType::DataType is synced with enum RSDataType in
622  // slang_rs_type_spec.h
623  RS_PRIMITIVE_TYPE_SET_DATA_TYPE(ST, getType());
624  return ST.take();
625}
626
627bool RSExportPrimitiveType::equals(const RSExportable *E) const {
628  CHECK_PARENT_EQUALITY(RSExportType, E);
629  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
630}
631
632/**************************** RSExportPointerType ****************************/
633
634const clang::Type *RSExportPointerType::IntegerType = NULL;
635
636RSExportPointerType
637*RSExportPointerType::Create(RSContext *Context,
638                             const clang::PointerType *PT,
639                             const llvm::StringRef &TypeName) {
640  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
641  const RSExportType *PointeeET;
642
643  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
644    PointeeET = RSExportType::Create(Context, PointeeType);
645  } else {
646    // Double or higher dimension of pointer, export as int*
647    assert(IntegerType != NULL && "Built-in integer type is not set");
648    PointeeET = RSExportPrimitiveType::Create(Context, IntegerType);
649  }
650
651  if (PointeeET == NULL) {
652    fprintf(stderr, "Failed to create type for pointee");
653    return NULL;
654  }
655
656  return new RSExportPointerType(Context, TypeName, PointeeET);
657}
658
659const llvm::Type *RSExportPointerType::convertToLLVMType() const {
660  const llvm::Type *PointeeType = mPointeeType->getLLVMType();
661  return llvm::PointerType::getUnqual(PointeeType);
662}
663
664union RSType *RSExportPointerType::convertToSpecType() const {
665  llvm::OwningPtr<union RSType> ST(new union RSType);
666
667  RS_TYPE_SET_CLASS(ST, RS_TC_Pointer);
668  RS_POINTER_TYPE_SET_POINTEE_TYPE(ST, getPointeeType()->getSpecType());
669
670  if (RS_POINTER_TYPE_GET_POINTEE_TYPE(ST) != NULL)
671    return ST.take();
672  else
673    return NULL;
674}
675
676bool RSExportPointerType::keep() {
677  if (!RSExportType::keep())
678    return false;
679  const_cast<RSExportType*>(mPointeeType)->keep();
680  return true;
681}
682
683bool RSExportPointerType::equals(const RSExportable *E) const {
684  CHECK_PARENT_EQUALITY(RSExportType, E);
685  return (static_cast<const RSExportPointerType*>(E)
686              ->getPointeeType()->equals(getPointeeType()));
687}
688
689/***************************** RSExportVectorType *****************************/
690llvm::StringRef
691RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
692  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
693
694  if ((ElementType->getTypeClass() != clang::Type::Builtin))
695    return llvm::StringRef();
696
697  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType,
698                                                  ElementType);
699  if ((EVT->getNumElements() < 1) ||
700      (EVT->getNumElements() > 4))
701    return llvm::StringRef();
702
703  switch (BT->getKind()) {
704    // Compiler is smart enough to optimize following *big if branches* since
705    // they all become "constant comparison" after macro expansion
706#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
707    case builtin_type: {                                      \
708      const char *Name[] = { cname"2", cname"3", cname"4" };  \
709      return Name[EVT->getNumElements() - 2];                 \
710      break;                                                  \
711    }
712#include "RSClangBuiltinEnums.inc"
713    default: {
714      return llvm::StringRef();
715    }
716  }
717}
718
719RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
720                                               const clang::ExtVectorType *EVT,
721                                               const llvm::StringRef &TypeName,
722                                               DataKind DK,
723                                               bool Normalized) {
724  assert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
725
726  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
727  RSExportPrimitiveType::DataType DT =
728      RSExportPrimitiveType::GetDataType(ElementType);
729
730  if (DT != RSExportPrimitiveType::DataTypeUnknown)
731    return new RSExportVectorType(Context,
732                                  TypeName,
733                                  DT,
734                                  DK,
735                                  Normalized,
736                                  EVT->getNumElements());
737  else
738    fprintf(stderr, "RSExportVectorType::Create : unsupported base element "
739                    "type\n");
740  return NULL;
741}
742
743const llvm::Type *RSExportVectorType::convertToLLVMType() const {
744  const llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
745  return llvm::VectorType::get(ElementType, getNumElement());
746}
747
748union RSType *RSExportVectorType::convertToSpecType() const {
749  llvm::OwningPtr<union RSType> ST(new union RSType);
750
751  RS_TYPE_SET_CLASS(ST, RS_TC_Vector);
752  RS_VECTOR_TYPE_SET_ELEMENT_TYPE(ST, getType());
753  RS_VECTOR_TYPE_SET_VECTOR_SIZE(ST, getNumElement());
754
755  return ST.take();
756}
757
758bool RSExportVectorType::equals(const RSExportable *E) const {
759  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
760  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
761              == getNumElement());
762}
763
764/***************************** RSExportMatrixType *****************************/
765RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
766                                               const clang::RecordType *RT,
767                                               const llvm::StringRef &TypeName,
768                                               unsigned Dim) {
769  assert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
770  assert((Dim > 1) && "Invalid dimension of matrix");
771
772  // Check whether the struct rs_matrix is in our expected form (but assume it's
773  // correct if we're not sure whether it's correct or not)
774  const clang::RecordDecl* RD = RT->getDecl();
775  RD = RD->getDefinition();
776  if (RD != NULL) {
777    // Find definition, perform further examination
778    if (RD->field_empty()) {
779      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
780                      "must have 1 field for saving values", TypeName.data());
781      return NULL;
782    }
783
784    clang::RecordDecl::field_iterator FIT = RD->field_begin();
785    const clang::FieldDecl *FD = *FIT;
786    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
787    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
788      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
789                      "first field should be an array with constant size",
790              TypeName.data());
791      return NULL;
792    }
793    const clang::ConstantArrayType *CAT =
794      static_cast<const clang::ConstantArrayType *>(FT);
795    const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
796    if ((ElementType == NULL) ||
797        (ElementType->getTypeClass() != clang::Type::Builtin) ||
798        (static_cast<const clang::BuiltinType *>(ElementType)->getKind()
799          != clang::BuiltinType::Float)) {
800      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
801                      "first field should be a float array", TypeName.data());
802      return NULL;
803    }
804
805    if (CAT->getSize() != Dim * Dim) {
806      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
807                      "first field should be an array with size %d",
808              TypeName.data(), Dim * Dim);
809      return NULL;
810    }
811
812    FIT++;
813    if (FIT != RD->field_end()) {
814      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
815                      "must have exactly 1 field", TypeName.data());
816      return NULL;
817    }
818  }
819
820  return new RSExportMatrixType(Context, TypeName, Dim);
821}
822
823const llvm::Type *RSExportMatrixType::convertToLLVMType() const {
824  // Construct LLVM type:
825  // struct {
826  //  float X[mDim * mDim];
827  // }
828
829  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
830  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
831                                            mDim * mDim);
832  return llvm::StructType::get(C, X, NULL);
833}
834
835union RSType *RSExportMatrixType::convertToSpecType() const {
836  llvm::OwningPtr<union RSType> ST(new union RSType);
837  RS_TYPE_SET_CLASS(ST, RS_TC_Matrix);
838  switch (getDim()) {
839    case 2: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix2x2); break;
840    case 3: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix3x3); break;
841    case 4: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix4x4); break;
842    default: assert(false && "Matrix type with unsupported dimension.");
843  }
844  return ST.take();
845}
846
847bool RSExportMatrixType::equals(const RSExportable *E) const {
848  CHECK_PARENT_EQUALITY(RSExportType, E);
849  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
850}
851
852/************************* RSExportConstantArrayType *************************/
853RSExportConstantArrayType
854*RSExportConstantArrayType::Create(RSContext *Context,
855                                   const clang::ConstantArrayType *CAT) {
856  assert(CAT != NULL && CAT->getTypeClass() == clang::Type::ConstantArray);
857
858  assert((CAT->getSize().getActiveBits() < 32) && "array too large");
859
860  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
861  assert((Size > 0) && "Constant array should have size greater than 0");
862
863  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
864  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
865
866  if (ElementET == NULL) {
867    fprintf(stderr, "RSExportConstantArrayType::Create : failed to create "
868                    "RSExportType for array element.\n");
869    return NULL;
870  }
871
872  return new RSExportConstantArrayType(Context,
873                                       ElementET,
874                                       Size);
875}
876
877const llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
878  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
879}
880
881union RSType *RSExportConstantArrayType::convertToSpecType() const {
882  llvm::OwningPtr<union RSType> ST(new union RSType);
883
884  RS_TYPE_SET_CLASS(ST, RS_TC_ConstantArray);
885  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_TYPE(
886      ST, getElementType()->getSpecType());
887  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_SIZE(ST, getSize());
888
889  if (RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(ST) != NULL)
890    return ST.take();
891  else
892    return NULL;
893}
894
895bool RSExportConstantArrayType::keep() {
896  if (!RSExportType::keep())
897    return false;
898  const_cast<RSExportType*>(mElementType)->keep();
899  return true;
900}
901
902bool RSExportConstantArrayType::equals(const RSExportable *E) const {
903  CHECK_PARENT_EQUALITY(RSExportType, E);
904  return ((static_cast<const RSExportConstantArrayType*>(E)
905              ->getSize() == getSize()) && (mElementType->equals(E)));
906}
907
908/**************************** RSExportRecordType ****************************/
909RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
910                                               const clang::RecordType *RT,
911                                               const llvm::StringRef &TypeName,
912                                               bool mIsArtificial) {
913  assert(RT != NULL && RT->getTypeClass() == clang::Type::Record);
914
915  const clang::RecordDecl *RD = RT->getDecl();
916  assert(RD->isStruct());
917
918  RD = RD->getDefinition();
919  if (RD == NULL) {
920    // TODO(zonr): warn that actual struct definition isn't declared in this
921    //             moudle.
922    fprintf(stderr, "RSExportRecordType::Create : this struct is not defined "
923                    "in this module.");
924    return NULL;
925  }
926
927  // Struct layout construct by clang. We rely on this for obtaining the
928  // alloc size of a struct and offset of every field in that struct.
929  const clang::ASTRecordLayout *RL =
930      &Context->getASTContext().getASTRecordLayout(RD);
931  assert((RL != NULL) && "Failed to retrieve the struct layout from Clang.");
932
933  RSExportRecordType *ERT =
934      new RSExportRecordType(Context,
935                             TypeName,
936                             RD->hasAttr<clang::PackedAttr>(),
937                             mIsArtificial,
938                             (RL->getSize() >> 3));
939  unsigned int Index = 0;
940
941  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
942           FE = RD->field_end();
943       FI != FE;
944       FI++, Index++) {
945#define FAILED_CREATE_FIELD(err)    do {         \
946      if (*err)                                                          \
947        fprintf(stderr, \
948                "RSExportRecordType::Create : failed to create field (%s)\n", \
949                err);                                                   \
950      delete ERT;                                                       \
951      return NULL;                                                      \
952    } while (false)
953
954    // FIXME: All fields should be primitive type
955    assert((*FI)->getKind() == clang::Decl::Field);
956    clang::FieldDecl *FD = *FI;
957
958    // We don't support bit field
959    //
960    // TODO(zonr): allow bitfield with size 8, 16, 32
961    if (FD->isBitField())
962      FAILED_CREATE_FIELD("bit field is not supported");
963
964    // Type
965    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
966
967    if (ET != NULL)
968      ERT->mFields.push_back(
969          new Field(ET, FD->getName(), ERT,
970                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
971    else
972      FAILED_CREATE_FIELD(FD->getName().str().c_str());
973#undef FAILED_CREATE_FIELD
974  }
975
976  return ERT;
977}
978
979const llvm::Type *RSExportRecordType::convertToLLVMType() const {
980  // Create an opaque type since struct may reference itself recursively.
981  llvm::PATypeHolder ResultHolder =
982      llvm::OpaqueType::get(getRSContext()->getLLVMContext());
983  setAbstractLLVMType(ResultHolder.get());
984
985  std::vector<const llvm::Type*> FieldTypes;
986
987  for (const_field_iterator FI = fields_begin(), FE = fields_end();
988       FI != FE;
989       FI++) {
990    const Field *F = *FI;
991    const RSExportType *FET = F->getType();
992
993    FieldTypes.push_back(FET->getLLVMType());
994  }
995
996  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
997                                               FieldTypes,
998                                               mIsPacked);
999  if (ST != NULL)
1000    static_cast<llvm::OpaqueType*>(ResultHolder.get())
1001        ->refineAbstractTypeTo(ST);
1002  else
1003    return NULL;
1004  return ResultHolder.get();
1005}
1006
// Builds the spec-type description of this record: a single zero-filled
// allocation sized for one RSType plus one RSRecordField per field
// (presumably the field entries are trailing storage addressed by the
// RS_RECORD_TYPE_* macros — TODO confirm in slang_rs_type_spec.h). Each field
// entry receives the field's name, its spec type and its data kind. Ownership
// of the allocation passes to the caller through ST.take().
//
// NOTE(review): the memory comes from ::operator new(AllocSize) but is held by
// an OwningPtr<union RSType>, which presumably releases it with a delete
// expression. That new/delete mismatch is technically undefined; confirm how
// callers free the returned pointer.
// NOTE(review): the name pointers stored below are c_str() results; this
// assumes the underlying strings outlive the returned spec type — verify.
union RSType *RSExportRecordType::convertToSpecType() const {
  unsigned NumFields = getFields().size();
  unsigned AllocSize = sizeof(union RSType) +
                       sizeof(struct RSRecordField) * NumFields;
  llvm::OwningPtr<union RSType> ST(
      reinterpret_cast<union RSType*>(operator new(AllocSize)));

  // Zero everything so unset members read as 0/NULL.
  ::memset(ST.get(), 0, AllocSize);

  RS_TYPE_SET_CLASS(ST, RS_TC_Record);
  RS_RECORD_TYPE_SET_NAME(ST, getName().c_str());
  RS_RECORD_TYPE_SET_NUM_FIELDS(ST, NumFields);

  // Registered before the fields are converted; presumably this lets a field
  // that refers back to this record find the partially built spec type.
  // Keep this ahead of the loop.
  setSpecTypeTemporarily(ST.get());

  unsigned FieldIdx = 0;
  for (const_field_iterator FI = fields_begin(), FE = fields_end();
       FI != FE;
       FI++, FieldIdx++) {
    const Field *F = *FI;

    RS_RECORD_TYPE_SET_FIELD_NAME(ST, FieldIdx, F->getName().c_str());
    RS_RECORD_TYPE_SET_FIELD_TYPE(ST, FieldIdx, F->getType()->getSpecType());

    // Primitive and vector fields carry their primitive data kind; all other
    // field classes are recorded as RS_DK_User.
    enum RSDataKind DK = RS_DK_User;
    if ((F->getType()->getClass() == ExportClassPrimitive) ||
        (F->getType()->getClass() == ExportClassVector)) {
      const RSExportPrimitiveType *EPT =
        static_cast<const RSExportPrimitiveType*>(F->getType());
      // enum RSExportPrimitiveType::DataKind is synced with enum RSDataKind in
      // slang_rs_type_spec.h
      DK = static_cast<enum RSDataKind>(EPT->getKind());
    }
    RS_RECORD_TYPE_SET_FIELD_DATA_KIND(ST, FieldIdx, DK);
  }

  // TODO(slang): Check whether all fields were created normally.
  // (A field whose getSpecType() is NULL is currently stored as-is.)

  return ST.take();
}
1047
1048bool RSExportRecordType::keep() {
1049  if (!RSExportType::keep())
1050    return false;
1051  for (std::list<const Field*>::iterator I = mFields.begin(),
1052          E = mFields.end();
1053       I != E;
1054       I++) {
1055    const_cast<RSExportType*>((*I)->getType())->keep();
1056  }
1057  return true;
1058}
1059
1060bool RSExportRecordType::equals(const RSExportable *E) const {
1061  CHECK_PARENT_EQUALITY(RSExportType, E);
1062
1063  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1064
1065  if (ERT->getFields().size() != getFields().size())
1066    return false;
1067
1068  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1069
1070  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1071    if (!(*AI)->getType()->equals((*BI)->getType()))
1072      return false;
1073    AI++;
1074    BI++;
1075  }
1076
1077  return true;
1078}
1079
1080}  // namespace slang
1081