slang_rs_export_type.cpp revision c808a99831115928b4648f4c8b86dc682594217a
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <list>
20#include <vector>
21
22#include "clang/AST/RecordLayout.h"
23
24#include "llvm/ADT/StringExtras.h"
25
26#include "llvm/DerivedTypes.h"
27
28#include "llvm/Target/TargetData.h"
29
30#include "llvm/Type.h"
31
32#include "slang_rs_context.h"
33#include "slang_rs_export_element.h"
34#include "slang_rs_type_spec.h"
35
// Helper for equals() overrides: makes the enclosing function return false
// when ParentClass::equals(E) does not hold. Wrapped in do { } while (0) so
// that the macro expands to exactly one statement and can be used safely
// inside unbraced if/else bodies. Callers must terminate it with ';'.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  do {                                        \
    if (!ParentClass::equals(E))              \
      return false;                           \
  } while (0)
39
40namespace slang {
41
42/****************************** RSExportType ******************************/
43bool RSExportType::NormalizeType(const clang::Type *&T,
44                                 llvm::StringRef &TypeName) {
45  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
46      llvm::SmallPtrSet<const clang::Type*, 8>();
47
48  if ((T = RSExportType::TypeExportable(T, SPS, false)) == NULL)
49    // TODO(zonr): warn that type not exportable.
50    return false;
51
52  // Get type name
53  TypeName = RSExportType::GetTypeName(T);
54  if (TypeName.empty())
55    // TODO(zonr): warning that the type is unnamed.
56    return false;
57
58  return true;
59}
60
61const clang::Type
62*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
63  if (DD) {
64    clang::QualType T;
65    if (DD->getTypeSourceInfo())
66      T = DD->getTypeSourceInfo()->getType();
67    else
68      T = DD->getType();
69
70    if (T.isNull())
71      return NULL;
72    else
73      return T.getTypePtr();
74  }
75  return NULL;
76}
77
78llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
79  T = GET_CANONICAL_TYPE(T);
80  if (T == NULL)
81    return llvm::StringRef();
82
83  switch (T->getTypeClass()) {
84    case clang::Type::Builtin: {
85      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
86
87      switch (BT->getKind()) {
88#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
89        case builtin_type:                                    \
90          return cname;                                       \
91        break;
92#include "RSClangBuiltinEnums.inc"
93        default: {
94          assert(false && "Unknown data type of the builtin");
95          break;
96        }
97      }
98      break;
99    }
100    case clang::Type::Record: {
101      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
102      llvm::StringRef Name = RD->getName();
103      if (Name.empty()) {
104          if (RD->getTypedefForAnonDecl() != NULL)
105            Name = RD->getTypedefForAnonDecl()->getName();
106
107          if (Name.empty())
108            // Try to find a name from redeclaration (i.e. typedef)
109            for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
110                     RE = RD->redecls_end();
111                 RI != RE;
112                 RI++) {
113              assert(*RI != NULL && "cannot be NULL object");
114
115              Name = (*RI)->getName();
116              if (!Name.empty())
117                break;
118            }
119      }
120      return Name;
121    }
122    case clang::Type::Pointer: {
123      // "*" plus pointee name
124      const clang::Type *PT = GET_POINTEE_TYPE(T);
125      llvm::StringRef PointeeName;
126      if (NormalizeType(PT, PointeeName)) {
127        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
128        Name[0] = '*';
129        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
130        Name[PointeeName.size() + 1] = '\0';
131        return Name;
132      }
133      break;
134    }
135    case clang::Type::ExtVector: {
136      const clang::ExtVectorType *EVT =
137          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
138      return RSExportVectorType::GetTypeName(EVT);
139      break;
140    }
141    case clang::Type::ConstantArray : {
142      // Construct name for a constant array is too complicated.
143      return DUMMY_TYPE_NAME_FOR_RS_CONSTANT_ARRAY_TYPE;
144    }
145    default: {
146      break;
147    }
148  }
149
150  return llvm::StringRef();
151}
152
153const clang::Type *RSExportType::TypeExportable(
154    const clang::Type *T,
155    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
156    bool InRecord) {
157  // Normalize first
158  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
159    return NULL;
160
161  if (SPS.count(T))
162    return T;
163
164  switch (T->getTypeClass()) {
165    case clang::Type::Builtin: {
166      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
167
168      switch (BT->getKind()) {
169#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
170        case builtin_type:
171#include "RSClangBuiltinEnums.inc"
172          return T;
173        default: {
174          return NULL;
175        }
176      }
177    }
178    case clang::Type::Record: {
179      if (RSExportPrimitiveType::GetRSSpecificType(T) !=
180          RSExportPrimitiveType::DataTypeUnknown)
181        return T;  // RS object type, no further checks are needed
182
183      // Check internal struct
184      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
185      if (RD != NULL)
186        RD = RD->getDefinition();
187
188      // Fast check
189      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
190        return NULL;
191
192      // Insert myself into checking set
193      SPS.insert(T);
194
195      // Check all element
196      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
197               FE = RD->field_end();
198           FI != FE;
199           FI++) {
200        const clang::FieldDecl *FD = *FI;
201        const clang::Type *FT = GetTypeOfDecl(FD);
202        FT = GET_CANONICAL_TYPE(FT);
203
204        if (!TypeExportable(FT, SPS, true)) {
205          fprintf(stderr, "Field `%s' in Record `%s' contains unsupported "
206                          "type\n", FD->getNameAsString().c_str(),
207                                    RD->getNameAsString().c_str());
208          FT->dump();
209          return NULL;
210        }
211      }
212
213      return T;
214    }
215    case clang::Type::Pointer: {
216      if (InRecord) {
217        return NULL;
218      }
219      const clang::PointerType *PT = UNSAFE_CAST_TYPE(clang::PointerType, T);
220      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
221
222      if (PointeeType->getTypeClass() == clang::Type::Pointer)
223        return T;
224      // We don't support pointer with array-type pointee or unsupported pointee
225      // type
226      if (PointeeType->isArrayType() ||
227          (TypeExportable(PointeeType, SPS, InRecord) == NULL))
228        return NULL;
229      else
230        return T;
231    }
232    case clang::Type::ExtVector: {
233      const clang::ExtVectorType *EVT =
234          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
235      // Only vector with size 2, 3 and 4 are supported.
236      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
237        return NULL;
238
239      // Check base element type
240      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
241
242      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
243          (TypeExportable(ElementType, SPS, InRecord) == NULL))
244        return NULL;
245      else
246        return T;
247    }
248    case clang::Type::ConstantArray: {
249      const clang::ConstantArrayType *CAT =
250          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);
251
252      // Check size
253      if (CAT->getSize().getActiveBits() > 32) {
254        fprintf(stderr, "RSExportConstantArrayType::Create : array with too "
255                        "large size (> 2^32).\n");
256        return NULL;
257      }
258      // Check element type
259      const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
260      if (ElementType->isArrayType()) {
261        fprintf(stderr, "RSExportType::TypeExportable : constant array with 2 "
262                        "or higher dimension of constant is not supported.\n");
263        return NULL;
264      }
265      if (TypeExportable(ElementType, SPS, InRecord) == NULL)
266        return NULL;
267      else
268        return T;
269    }
270    default: {
271      return NULL;
272    }
273  }
274}
275
276RSExportType *RSExportType::Create(RSContext *Context,
277                                   const clang::Type *T,
278                                   const llvm::StringRef &TypeName) {
279  // Lookup the context to see whether the type was processed before.
280  // Newly created RSExportType will insert into context
281  // in RSExportType::RSExportType()
282  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
283
284  if (ETI != Context->export_types_end())
285    return ETI->second;
286
287  RSExportType *ET = NULL;
288  switch (T->getTypeClass()) {
289    case clang::Type::Record: {
290      RSExportPrimitiveType::DataType dt =
291          RSExportPrimitiveType::GetRSSpecificType(TypeName);
292      switch (dt) {
293        case RSExportPrimitiveType::DataTypeUnknown: {
294          // User-defined types
295          ET = RSExportRecordType::Create(Context,
296                                          T->getAsStructureType(),
297                                          TypeName);
298          break;
299        }
300        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
301          // 2 x 2 Matrix type
302          ET = RSExportMatrixType::Create(Context,
303                                          T->getAsStructureType(),
304                                          TypeName,
305                                          2);
306          break;
307        }
308        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
309          // 3 x 3 Matrix type
310          ET = RSExportMatrixType::Create(Context,
311                                          T->getAsStructureType(),
312                                          TypeName,
313                                          3);
314          break;
315        }
316        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
317          // 4 x 4 Matrix type
318          ET = RSExportMatrixType::Create(Context,
319                                          T->getAsStructureType(),
320                                          TypeName,
321                                          4);
322          break;
323        }
324        default: {
325          // Others are primitive types
326          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
327          break;
328        }
329      }
330      break;
331    }
332    case clang::Type::Builtin: {
333      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
334      break;
335    }
336    case clang::Type::Pointer: {
337      ET = RSExportPointerType::Create(Context,
338                                       UNSAFE_CAST_TYPE(clang::PointerType, T),
339                                       TypeName);
340      // FIXME: free the name (allocated in RSExportType::GetTypeName)
341      delete [] TypeName.data();
342      break;
343    }
344    case clang::Type::ExtVector: {
345      ET = RSExportVectorType::Create(Context,
346                                      UNSAFE_CAST_TYPE(clang::ExtVectorType, T),
347                                      TypeName);
348      break;
349    }
350    case clang::Type::ConstantArray: {
351      ET = RSExportConstantArrayType::Create(
352              Context,
353              UNSAFE_CAST_TYPE(clang::ConstantArrayType, T));
354      break;
355    }
356    default: {
357      // TODO(zonr): warn that type is not exportable.
358      fprintf(stderr,
359              "RSExportType::Create : type '%s' is not exportable\n",
360              T->getTypeClassName());
361      break;
362    }
363  }
364
365  return ET;
366}
367
368RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
369  llvm::StringRef TypeName;
370  if (NormalizeType(T, TypeName))
371    return Create(Context, T, TypeName);
372  else
373    return NULL;
374}
375
376RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
377                                           const clang::VarDecl *VD) {
378  return RSExportType::Create(Context, GetTypeOfDecl(VD));
379}
380
381size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
382  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
383      ET->getLLVMType());
384}
385
386size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
387  if (ET->getClass() == RSExportType::ExportClassRecord)
388    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
389  else
390    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
391        ET->getLLVMType());
392}
393
394RSExportType::RSExportType(RSContext *Context,
395                           ExportClass Class,
396                           const llvm::StringRef &Name)
397    : RSExportable(Context, RSExportable::EX_TYPE),
398      mClass(Class),
399      // Make a copy on Name since memory stored @Name is either allocated in
400      // ASTContext or allocated in GetTypeName which will be destroyed later.
401      mName(Name.data(), Name.size()),
402      mLLVMType(NULL),
403      mSpecType(NULL) {
404  // Don't cache the type whose name start with '<'. Those type failed to
405  // get their name since constructing their name in GetTypeName() requiring
406  // complicated work.
407  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
408    // TODO(zonr): Need to check whether the insertion is successful or not.
409    Context->insertExportType(llvm::StringRef(Name), this);
410  return;
411}
412
413bool RSExportType::keep() {
414  if (!RSExportable::keep())
415    return false;
416  // Invalidate converted LLVM type.
417  mLLVMType = NULL;
418  return true;
419}
420
421bool RSExportType::equals(const RSExportable *E) const {
422  CHECK_PARENT_EQUALITY(RSExportable, E);
423  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
424}
425
426RSExportType::~RSExportType() {
427  delete mSpecType;
428}
429
430/************************** RSExportPrimitiveType **************************/
// Name -> DataType lookup table for RS matrix and object types. It is filled
// on first use by GetRSSpecificType(const llvm::StringRef &).
llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
RSExportPrimitiveType::RSSpecificTypeMap;

// LLVM type shared by all RS object types; created on demand in
// convertToLLVMType().
llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
435
436bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
437  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
438    return true;
439  else
440    return false;
441}
442
443RSExportPrimitiveType::DataType
444RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
445  if (TypeName.empty())
446    return DataTypeUnknown;
447
448  if (RSSpecificTypeMap->empty()) {
449#define ENUM_RS_MATRIX_TYPE(type, cname, dim)                       \
450    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
451#include "RSMatrixTypeEnums.inc"
452#define ENUM_RS_OBJECT_TYPE(type, cname)                            \
453    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
454#include "RSObjectTypeEnums.inc"
455  }
456
457  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
458  if (I == RSSpecificTypeMap->end())
459    return DataTypeUnknown;
460  else
461    return I->getValue();
462}
463
464RSExportPrimitiveType::DataType
465RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
466  T = GET_CANONICAL_TYPE(T);
467  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
468    return DataTypeUnknown;
469
470  return GetRSSpecificType( RSExportType::GetTypeName(T) );
471}
472
473bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
474  return ((DT >= FirstRSMatrixType) && (DT <= LastRSMatrixType));
475}
476
477bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
478  return ((DT >= FirstRSObjectType) && (DT <= LastRSObjectType));
479}
480
// Size in bits of each data type. One entry is emitted per ENUM_RS_DATA_TYPE
// record in RSDataTypeEnums.inc, followed by a trailing 0 for DataTypeMax.
// GetSizeInBits() indexes this table with the DataType value.
const size_t RSExportPrimitiveType::SizeOfDataTypeInBits[] = {
#define ENUM_RS_DATA_TYPE(type, cname, bits)  \
  bits,
#include "RSDataTypeEnums.inc"
  0   // DataTypeMax
};
487
488size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
489  assert(((EPT->getType() > DataTypeUnknown) &&
490          (EPT->getType() < DataTypeMax)) &&
491         "RSExportPrimitiveType::GetSizeInBits : unknown data type");
492  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
493}
494
495RSExportPrimitiveType::DataType
496RSExportPrimitiveType::GetDataType(const clang::Type *T) {
497  if (T == NULL)
498    return DataTypeUnknown;
499
500  switch (T->getTypeClass()) {
501    case clang::Type::Builtin: {
502      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
503      switch (BT->getKind()) {
504#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
505        case builtin_type: {                                  \
506          return DataType ## type;                            \
507        }
508#include "RSClangBuiltinEnums.inc"
509        // The size of type WChar depend on platform so we abandon the support
510        // to them.
511        default: {
512          fprintf(stderr, "RSExportPrimitiveType::GetDataType : unsupported "
513                          "built-in type '%s'\n.", T->getTypeClassName());
514          break;
515        }
516      }
517      break;
518    }
519    case clang::Type::Record: {
520      // must be RS object type
521      return RSExportPrimitiveType::GetRSSpecificType(T);
522    }
523    default: {
524      fprintf(stderr, "RSExportPrimitiveType::GetDataType : type '%s' is not "
525                      "supported primitive type\n", T->getTypeClassName());
526      break;
527    }
528  }
529
530  return DataTypeUnknown;
531}
532
533RSExportPrimitiveType
534*RSExportPrimitiveType::Create(RSContext *Context,
535                               const clang::Type *T,
536                               const llvm::StringRef &TypeName,
537                               DataKind DK,
538                               bool Normalized) {
539  DataType DT = GetDataType(T);
540
541  if ((DT == DataTypeUnknown) || TypeName.empty())
542    return NULL;
543  else
544    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
545                                     DT, DK, Normalized);
546}
547
548RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
549                                                     const clang::Type *T,
550                                                     DataKind DK) {
551  llvm::StringRef TypeName;
552  if (RSExportType::NormalizeType(T, TypeName) && IsPrimitiveType(T))
553    return Create(Context, T, TypeName, DK);
554  else
555    return NULL;
556}
557
558const llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
559  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
560
561  if (isRSObjectType()) {
562    // struct {
563    //   int *p;
564    // } __attribute__((packed, aligned(pointer_size)))
565    //
566    // which is
567    //
568    // <{ [1 x i32] }> in LLVM
569    //
570    if (RSObjectLLVMType == NULL) {
571      std::vector<const llvm::Type *> Elements;
572      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
573      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
574    }
575    return RSObjectLLVMType;
576  }
577
578  switch (mType) {
579    case DataTypeFloat32: {
580      return llvm::Type::getFloatTy(C);
581      break;
582    }
583    case DataTypeFloat64: {
584      return llvm::Type::getDoubleTy(C);
585      break;
586    }
587    case DataTypeBoolean: {
588      return llvm::Type::getInt1Ty(C);
589      break;
590    }
591    case DataTypeSigned8:
592    case DataTypeUnsigned8: {
593      return llvm::Type::getInt8Ty(C);
594      break;
595    }
596    case DataTypeSigned16:
597    case DataTypeUnsigned16:
598    case DataTypeUnsigned565:
599    case DataTypeUnsigned5551:
600    case DataTypeUnsigned4444: {
601      return llvm::Type::getInt16Ty(C);
602      break;
603    }
604    case DataTypeSigned32:
605    case DataTypeUnsigned32: {
606      return llvm::Type::getInt32Ty(C);
607      break;
608    }
609    case DataTypeSigned64:
610    case DataTypeUnsigned64: {
611      return llvm::Type::getInt64Ty(C);
612      break;
613    }
614    default: {
615      assert(false && "Unknown data type");
616    }
617  }
618
619  return NULL;
620}
621
622union RSType *RSExportPrimitiveType::convertToSpecType() const {
623  llvm::OwningPtr<union RSType> ST(new union RSType);
624  RS_TYPE_SET_CLASS(ST, RS_TC_Primitive);
625  // enum RSExportPrimitiveType::DataType is synced with enum RSDataType in
626  // slang_rs_type_spec.h
627  RS_PRIMITIVE_TYPE_SET_DATA_TYPE(ST, getType());
628  return ST.take();
629}
630
631bool RSExportPrimitiveType::equals(const RSExportable *E) const {
632  CHECK_PARENT_EQUALITY(RSExportType, E);
633  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
634}
635
636/**************************** RSExportPointerType ****************************/
637
// Integer type used as the pointee when exporting pointer-to-pointer types
// (see RSExportPointerType::Create). Starts out NULL and must be assigned
// before such a pointer is created, as Create() asserts on it.
const clang::Type *RSExportPointerType::IntegerType = NULL;
639
640RSExportPointerType
641*RSExportPointerType::Create(RSContext *Context,
642                             const clang::PointerType *PT,
643                             const llvm::StringRef &TypeName) {
644  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
645  const RSExportType *PointeeET;
646
647  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
648    PointeeET = RSExportType::Create(Context, PointeeType);
649  } else {
650    // Double or higher dimension of pointer, export as int*
651    assert(IntegerType != NULL && "Built-in integer type is not set");
652    PointeeET = RSExportPrimitiveType::Create(Context, IntegerType);
653  }
654
655  if (PointeeET == NULL) {
656    fprintf(stderr, "Failed to create type for pointee");
657    return NULL;
658  }
659
660  return new RSExportPointerType(Context, TypeName, PointeeET);
661}
662
663const llvm::Type *RSExportPointerType::convertToLLVMType() const {
664  const llvm::Type *PointeeType = mPointeeType->getLLVMType();
665  return llvm::PointerType::getUnqual(PointeeType);
666}
667
668union RSType *RSExportPointerType::convertToSpecType() const {
669  llvm::OwningPtr<union RSType> ST(new union RSType);
670
671  RS_TYPE_SET_CLASS(ST, RS_TC_Pointer);
672  RS_POINTER_TYPE_SET_POINTEE_TYPE(ST, getPointeeType()->getSpecType());
673
674  if (RS_POINTER_TYPE_GET_POINTEE_TYPE(ST) != NULL)
675    return ST.take();
676  else
677    return NULL;
678}
679
680bool RSExportPointerType::keep() {
681  if (!RSExportType::keep())
682    return false;
683  const_cast<RSExportType*>(mPointeeType)->keep();
684  return true;
685}
686
687bool RSExportPointerType::equals(const RSExportable *E) const {
688  CHECK_PARENT_EQUALITY(RSExportType, E);
689  return (static_cast<const RSExportPointerType*>(E)
690              ->getPointeeType()->equals(getPointeeType()));
691}
692
693/***************************** RSExportVectorType *****************************/
694llvm::StringRef
695RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
696  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
697
698  if ((ElementType->getTypeClass() != clang::Type::Builtin))
699    return llvm::StringRef();
700
701  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType,
702                                                  ElementType);
703  if ((EVT->getNumElements() < 1) ||
704      (EVT->getNumElements() > 4))
705    return llvm::StringRef();
706
707  switch (BT->getKind()) {
708    // Compiler is smart enough to optimize following *big if branches* since
709    // they all become "constant comparison" after macro expansion
710#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
711    case builtin_type: {                                      \
712      const char *Name[] = { cname"2", cname"3", cname"4" };  \
713      return Name[EVT->getNumElements() - 2];                 \
714      break;                                                  \
715    }
716#include "RSClangBuiltinEnums.inc"
717    default: {
718      return llvm::StringRef();
719    }
720  }
721}
722
723RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
724                                               const clang::ExtVectorType *EVT,
725                                               const llvm::StringRef &TypeName,
726                                               DataKind DK,
727                                               bool Normalized) {
728  assert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
729
730  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
731  RSExportPrimitiveType::DataType DT =
732      RSExportPrimitiveType::GetDataType(ElementType);
733
734  if (DT != RSExportPrimitiveType::DataTypeUnknown)
735    return new RSExportVectorType(Context,
736                                  TypeName,
737                                  DT,
738                                  DK,
739                                  Normalized,
740                                  EVT->getNumElements());
741  else
742    fprintf(stderr, "RSExportVectorType::Create : unsupported base element "
743                    "type\n");
744  return NULL;
745}
746
747const llvm::Type *RSExportVectorType::convertToLLVMType() const {
748  const llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
749  return llvm::VectorType::get(ElementType, getNumElement());
750}
751
752union RSType *RSExportVectorType::convertToSpecType() const {
753  llvm::OwningPtr<union RSType> ST(new union RSType);
754
755  RS_TYPE_SET_CLASS(ST, RS_TC_Vector);
756  RS_VECTOR_TYPE_SET_ELEMENT_TYPE(ST, getType());
757  RS_VECTOR_TYPE_SET_VECTOR_SIZE(ST, getNumElement());
758
759  return ST.take();
760}
761
762bool RSExportVectorType::equals(const RSExportable *E) const {
763  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
764  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
765              == getNumElement());
766}
767
768/***************************** RSExportMatrixType *****************************/
769RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
770                                               const clang::RecordType *RT,
771                                               const llvm::StringRef &TypeName,
772                                               unsigned Dim) {
773  assert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
774  assert((Dim > 1) && "Invalid dimension of matrix");
775
776  // Check whether the struct rs_matrix is in our expected form (but assume it's
777  // correct if we're not sure whether it's correct or not)
778  const clang::RecordDecl* RD = RT->getDecl();
779  RD = RD->getDefinition();
780  if (RD != NULL) {
781    // Find definition, perform further examination
782    if (RD->field_empty()) {
783      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
784                      "must have 1 field for saving values", TypeName.data());
785      return NULL;
786    }
787
788    clang::RecordDecl::field_iterator FIT = RD->field_begin();
789    const clang::FieldDecl *FD = *FIT;
790    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
791    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
792      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
793                      "first field should be an array with constant size",
794              TypeName.data());
795      return NULL;
796    }
797    const clang::ConstantArrayType *CAT =
798      static_cast<const clang::ConstantArrayType *>(FT);
799    const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
800    if ((ElementType == NULL) ||
801        (ElementType->getTypeClass() != clang::Type::Builtin) ||
802        (static_cast<const clang::BuiltinType *>(ElementType)->getKind()
803          != clang::BuiltinType::Float)) {
804      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
805                      "first field should be a float array", TypeName.data());
806      return NULL;
807    }
808
809    if (CAT->getSize() != Dim * Dim) {
810      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
811                      "first field should be an array with size %d",
812              TypeName.data(), Dim * Dim);
813      return NULL;
814    }
815
816    FIT++;
817    if (FIT != RD->field_end()) {
818      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
819                      "must have exactly 1 field", TypeName.data());
820      return NULL;
821    }
822  }
823
824  return new RSExportMatrixType(Context, TypeName, Dim);
825}
826
827const llvm::Type *RSExportMatrixType::convertToLLVMType() const {
828  // Construct LLVM type:
829  // struct {
830  //  float X[mDim * mDim];
831  // }
832
833  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
834  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
835                                            mDim * mDim);
836  return llvm::StructType::get(C, X, NULL);
837}
838
839union RSType *RSExportMatrixType::convertToSpecType() const {
840  llvm::OwningPtr<union RSType> ST(new union RSType);
841  RS_TYPE_SET_CLASS(ST, RS_TC_Matrix);
842  switch (getDim()) {
843    case 2: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix2x2); break;
844    case 3: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix3x3); break;
845    case 4: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix4x4); break;
846    default: assert(false && "Matrix type with unsupported dimension.");
847  }
848  return ST.take();
849}
850
851bool RSExportMatrixType::equals(const RSExportable *E) const {
852  CHECK_PARENT_EQUALITY(RSExportType, E);
853  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
854}
855
856/************************* RSExportConstantArrayType *************************/
857RSExportConstantArrayType
858*RSExportConstantArrayType::Create(RSContext *Context,
859                                   const clang::ConstantArrayType *CAT) {
860  assert(CAT != NULL && CAT->getTypeClass() == clang::Type::ConstantArray);
861
862  assert((CAT->getSize().getActiveBits() < 32) && "array too large");
863
864  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
865  assert((Size > 0) && "Constant array should have size greater than 0");
866
867  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
868  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
869
870  if (ElementET == NULL) {
871    fprintf(stderr, "RSExportConstantArrayType::Create : failed to create "
872                    "RSExportType for array element.\n");
873    return NULL;
874  }
875
876  return new RSExportConstantArrayType(Context,
877                                       ElementET,
878                                       Size);
879}
880
881const llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
882  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
883}
884
885union RSType *RSExportConstantArrayType::convertToSpecType() const {
886  llvm::OwningPtr<union RSType> ST(new union RSType);
887
888  RS_TYPE_SET_CLASS(ST, RS_TC_ConstantArray);
889  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_TYPE(
890      ST, getElementType()->getSpecType());
891  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_SIZE(ST, getSize());
892
893  if (RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(ST) != NULL)
894    return ST.take();
895  else
896    return NULL;
897}
898
899bool RSExportConstantArrayType::keep() {
900  if (!RSExportType::keep())
901    return false;
902  const_cast<RSExportType*>(mElementType)->keep();
903  return true;
904}
905
906bool RSExportConstantArrayType::equals(const RSExportable *E) const {
907  CHECK_PARENT_EQUALITY(RSExportType, E);
908  return ((static_cast<const RSExportConstantArrayType*>(E)
909              ->getSize() == getSize()) && (mElementType->equals(E)));
910}
911
912/**************************** RSExportRecordType ****************************/
913RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
914                                               const clang::RecordType *RT,
915                                               const llvm::StringRef &TypeName,
916                                               bool mIsArtificial) {
917  assert(RT != NULL && RT->getTypeClass() == clang::Type::Record);
918
919  const clang::RecordDecl *RD = RT->getDecl();
920  assert(RD->isStruct());
921
922  RD = RD->getDefinition();
923  if (RD == NULL) {
924    // TODO(zonr): warn that actual struct definition isn't declared in this
925    //             moudle.
926    fprintf(stderr, "RSExportRecordType::Create : this struct is not defined "
927                    "in this module.");
928    return NULL;
929  }
930
931  // Struct layout construct by clang. We rely on this for obtaining the
932  // alloc size of a struct and offset of every field in that struct.
933  const clang::ASTRecordLayout *RL =
934      &Context->getASTContext().getASTRecordLayout(RD);
935  assert((RL != NULL) && "Failed to retrieve the struct layout from Clang.");
936
937  RSExportRecordType *ERT =
938      new RSExportRecordType(Context,
939                             TypeName,
940                             RD->hasAttr<clang::PackedAttr>(),
941                             mIsArtificial,
942                             (RL->getSize() >> 3));
943  unsigned int Index = 0;
944
945  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
946           FE = RD->field_end();
947       FI != FE;
948       FI++, Index++) {
949#define FAILED_CREATE_FIELD(err)    do {         \
950      if (*err)                                                          \
951        fprintf(stderr, \
952                "RSExportRecordType::Create : failed to create field (%s)\n", \
953                err);                                                   \
954      delete ERT;                                                       \
955      return NULL;                                                      \
956    } while (false)
957
958    // FIXME: All fields should be primitive type
959    assert((*FI)->getKind() == clang::Decl::Field);
960    clang::FieldDecl *FD = *FI;
961
962    // We don't support bit field
963    //
964    // TODO(zonr): allow bitfield with size 8, 16, 32
965    if (FD->isBitField())
966      FAILED_CREATE_FIELD("bit field is not supported");
967
968    // Type
969    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
970
971    if (ET != NULL)
972      ERT->mFields.push_back(
973          new Field(ET, FD->getName(), ERT,
974                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
975    else
976      FAILED_CREATE_FIELD(FD->getName().str().c_str());
977#undef FAILED_CREATE_FIELD
978  }
979
980  return ERT;
981}
982
983const llvm::Type *RSExportRecordType::convertToLLVMType() const {
984  // Create an opaque type since struct may reference itself recursively.
985  llvm::PATypeHolder ResultHolder =
986      llvm::OpaqueType::get(getRSContext()->getLLVMContext());
987  setAbstractLLVMType(ResultHolder.get());
988
989  std::vector<const llvm::Type*> FieldTypes;
990
991  for (const_field_iterator FI = fields_begin(), FE = fields_end();
992       FI != FE;
993       FI++) {
994    const Field *F = *FI;
995    const RSExportType *FET = F->getType();
996
997    FieldTypes.push_back(FET->getLLVMType());
998  }
999
1000  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1001                                               FieldTypes,
1002                                               mIsPacked);
1003  if (ST != NULL)
1004    static_cast<llvm::OpaqueType*>(ResultHolder.get())
1005        ->refineAbstractTypeTo(ST);
1006  else
1007    return NULL;
1008  return ResultHolder.get();
1009}
1010
1011union RSType *RSExportRecordType::convertToSpecType() const {
1012  unsigned NumFields = getFields().size();
1013  unsigned AllocSize = sizeof(union RSType) +
1014                       sizeof(struct RSRecordField) * NumFields;
1015  llvm::OwningPtr<union RSType> ST(
1016      reinterpret_cast<union RSType*>(operator new(AllocSize)));
1017
1018  ::memset(ST.get(), 0, AllocSize);
1019
1020  RS_TYPE_SET_CLASS(ST, RS_TC_Record);
1021  RS_RECORD_TYPE_SET_NAME(ST, getName().c_str());
1022  RS_RECORD_TYPE_SET_NUM_FIELDS(ST, NumFields);
1023
1024  setSpecTypeTemporarily(ST.get());
1025
1026  unsigned FieldIdx = 0;
1027  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1028       FI != FE;
1029       FI++, FieldIdx++) {
1030    const Field *F = *FI;
1031
1032    RS_RECORD_TYPE_SET_FIELD_NAME(ST, FieldIdx, F->getName().c_str());
1033    RS_RECORD_TYPE_SET_FIELD_TYPE(ST, FieldIdx, F->getType()->getSpecType());
1034
1035    enum RSDataKind DK = RS_DK_User;
1036    if ((F->getType()->getClass() == ExportClassPrimitive) ||
1037        (F->getType()->getClass() == ExportClassVector)) {
1038      const RSExportPrimitiveType *EPT =
1039        static_cast<const RSExportPrimitiveType*>(F->getType());
1040      // enum RSExportPrimitiveType::DataKind is synced with enum RSDataKind in
1041      // slang_rs_type_spec.h
1042      DK = static_cast<enum RSDataKind>(EPT->getKind());
1043    }
1044    RS_RECORD_TYPE_SET_FIELD_DATA_KIND(ST, FieldIdx, DK);
1045  }
1046
1047  // TODO(slang): Check whether all fields were created normally.
1048
1049  return ST.take();
1050}
1051
1052bool RSExportRecordType::keep() {
1053  if (!RSExportType::keep())
1054    return false;
1055  for (std::list<const Field*>::iterator I = mFields.begin(),
1056          E = mFields.end();
1057       I != E;
1058       I++) {
1059    const_cast<RSExportType*>((*I)->getType())->keep();
1060  }
1061  return true;
1062}
1063
1064bool RSExportRecordType::equals(const RSExportable *E) const {
1065  CHECK_PARENT_EQUALITY(RSExportType, E);
1066
1067  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1068
1069  if (ERT->getFields().size() != getFields().size())
1070    return false;
1071
1072  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1073
1074  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1075    if (!(*AI)->getType()->equals((*BI)->getType()))
1076      return false;
1077    AI++;
1078    BI++;
1079  }
1080
1081  return true;
1082}
1083
1084}  // namespace slang
1085