slang_rs_export_type.cpp revision b1771ef128b10c4d4575634828006bfba20b1d9c
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <vector>
20
21#include "llvm/Type.h"
22#include "llvm/DerivedTypes.h"
23
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/Target/TargetData.h"
26
27#include "clang/AST/RecordLayout.h"
28
29#include "slang_rs_context.h"
30#include "slang_rs_export_element.h"
31#include "slang_rs_type_spec.h"
32
// Early-return helper for equals() overrides: if ParentClass::equals(E)
// rejects E, the enclosing function returns false.
//
// The body is wrapped in do { ... } while (0) so the macro expands to exactly
// one statement. The former bare-`if` form, followed by the `;` written at
// the call site, produced two statements, so it could not be used as the body
// of an unbraced if/else (the `else` became a syntax error).
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  do {                                        \
    if (!ParentClass::equals(E))              \
      return false;                           \
  } while (0)
36
37using namespace slang;
38
39/****************************** RSExportType ******************************/
40bool RSExportType::NormalizeType(const clang::Type *&T,
41                                 llvm::StringRef &TypeName) {
42  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
43      llvm::SmallPtrSet<const clang::Type*, 8>();
44
45  if ((T = RSExportType::TypeExportable(T, SPS)) == NULL)
46    // TODO(zonr): warn that type not exportable.
47    return false;
48
49  // Get type name
50  TypeName = RSExportType::GetTypeName(T);
51  if (TypeName.empty())
52    // TODO(zonr): warning that the type is unnamed.
53    return false;
54
55  return true;
56}
57
58const clang::Type
59*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
60  if (DD) {
61    clang::QualType T;
62    if (DD->getTypeSourceInfo())
63      T = DD->getTypeSourceInfo()->getType();
64    else
65      T = DD->getType();
66
67    if (T.isNull())
68      return NULL;
69    else
70      return T.getTypePtr();
71  }
72  return NULL;
73}
74
75llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
76  T = GET_CANONICAL_TYPE(T);
77  if (T == NULL)
78    return llvm::StringRef();
79
80  switch (T->getTypeClass()) {
81    case clang::Type::Builtin: {
82      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
83
84      switch (BT->getKind()) {
85#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
86        case builtin_type:                                    \
87          return cname;                                       \
88        break;
89#include "RSClangBuiltinEnums.inc"
90        default: {
91          assert(false && "Unknown data type of the builtin");
92          break;
93        }
94      }
95      break;
96    }
97    case clang::Type::Record: {
98      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
99      llvm::StringRef Name = RD->getName();
100      if (Name.empty()) {
101          if (RD->getTypedefForAnonDecl() != NULL)
102            Name = RD->getTypedefForAnonDecl()->getName();
103
104          if (Name.empty())
105            // Try to find a name from redeclaration (i.e. typedef)
106            for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
107                     RE = RD->redecls_end();
108                 RI != RE;
109                 RI++) {
110              assert(*RI != NULL && "cannot be NULL object");
111
112              Name = (*RI)->getName();
113              if (!Name.empty())
114                break;
115            }
116      }
117      return Name;
118    }
119    case clang::Type::Pointer: {
120      // "*" plus pointee name
121      const clang::Type *PT = GET_POINTEE_TYPE(T);
122      llvm::StringRef PointeeName;
123      if (NormalizeType(PT, PointeeName)) {
124        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
125        Name[0] = '*';
126        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
127        Name[PointeeName.size() + 1] = '\0';
128        return Name;
129      }
130      break;
131    }
132    case clang::Type::ExtVector: {
133      const clang::ExtVectorType *EVT =
134          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
135      return RSExportVectorType::GetTypeName(EVT);
136      break;
137    }
138    case clang::Type::ConstantArray : {
139      // Construct name for a constant array is too complicated.
140      return DUMMY_TYPE_NAME_FOR_RS_CONSTANT_ARRAY_TYPE;
141    }
142    default: {
143      break;
144    }
145  }
146
147  return llvm::StringRef();
148}
149
150const clang::Type *RSExportType::TypeExportable(
151    const clang::Type *T,
152    llvm::SmallPtrSet<const clang::Type*, 8>& SPS) {
153  // Normalize first
154  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
155    return NULL;
156
157  if (SPS.count(T))
158    return T;
159
160  switch (T->getTypeClass()) {
161    case clang::Type::Builtin: {
162      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
163
164      switch (BT->getKind()) {
165#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
166        case builtin_type:
167#include "RSClangBuiltinEnums.inc"
168          return T;
169        default: {
170          return NULL;
171        }
172      }
173    }
174    case clang::Type::Record: {
175      if (RSExportPrimitiveType::GetRSSpecificType(T) !=
176          RSExportPrimitiveType::DataTypeUnknown)
177        return T;  // RS object type, no further checks are needed
178
179      // Check internal struct
180      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
181      if (RD != NULL)
182        RD = RD->getDefinition();
183
184      // Fast check
185      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
186        return NULL;
187
188      // Insert myself into checking set
189      SPS.insert(T);
190
191      // Check all element
192      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
193               FE = RD->field_end();
194           FI != FE;
195           FI++) {
196        const clang::FieldDecl *FD = *FI;
197        const clang::Type *FT = GetTypeOfDecl(FD);
198        FT = GET_CANONICAL_TYPE(FT);
199
200        if (!TypeExportable(FT, SPS)) {
201          fprintf(stderr, "Field `%s' in Record `%s' contains unsupported "
202                          "type\n", FD->getNameAsString().c_str(),
203                                    RD->getNameAsString().c_str());
204          FT->dump();
205          return NULL;
206        }
207      }
208
209      return T;
210    }
211    case clang::Type::Pointer: {
212      const clang::PointerType *PT = UNSAFE_CAST_TYPE(clang::PointerType, T);
213      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
214
215      if (PointeeType->getTypeClass() == clang::Type::Pointer)
216        return T;
217      // We don't support pointer with array-type pointee or unsupported pointee
218      // type
219      if (PointeeType->isArrayType() ||
220         (TypeExportable(PointeeType, SPS) == NULL) )
221        return NULL;
222      else
223        return T;
224    }
225    case clang::Type::ExtVector: {
226      const clang::ExtVectorType *EVT =
227          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
228      // Only vector with size 2, 3 and 4 are supported.
229      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
230        return NULL;
231
232      // Check base element type
233      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
234
235      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
236          (TypeExportable(ElementType, SPS) == NULL))
237        return NULL;
238      else
239        return T;
240    }
241    case clang::Type::ConstantArray: {
242      const clang::ConstantArrayType *CAT =
243          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);
244
245      // Check size
246      if (CAT->getSize().getActiveBits() > 32) {
247        fprintf(stderr, "RSExportConstantArrayType::Create : array with too "
248                        "large size (> 2^32).\n");
249        return NULL;
250      }
251      // Check element type
252      const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
253      if (ElementType->isArrayType()) {
254        fprintf(stderr, "RSExportType::TypeExportable : constant array with 2 "
255                        "or higher dimension of constant is not supported.\n");
256        return NULL;
257      }
258      if (TypeExportable(ElementType, SPS) == NULL)
259        return NULL;
260      else
261        return T;
262    }
263    default: {
264      return NULL;
265    }
266  }
267}
268
269RSExportType *RSExportType::Create(RSContext *Context,
270                                   const clang::Type *T,
271                                   const llvm::StringRef &TypeName) {
272  // Lookup the context to see whether the type was processed before.
273  // Newly created RSExportType will insert into context
274  // in RSExportType::RSExportType()
275  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
276
277  if (ETI != Context->export_types_end())
278    return ETI->second;
279
280  RSExportType *ET = NULL;
281  switch (T->getTypeClass()) {
282    case clang::Type::Record: {
283      RSExportPrimitiveType::DataType dt =
284          RSExportPrimitiveType::GetRSSpecificType(TypeName);
285      switch (dt) {
286        case RSExportPrimitiveType::DataTypeUnknown: {
287          // User-defined types
288          ET = RSExportRecordType::Create(Context,
289                                          T->getAsStructureType(),
290                                          TypeName);
291          break;
292        }
293        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
294          // 2 x 2 Matrix type
295          ET = RSExportMatrixType::Create(Context,
296                                          T->getAsStructureType(),
297                                          TypeName,
298                                          2);
299          break;
300        }
301        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
302          // 3 x 3 Matrix type
303          ET = RSExportMatrixType::Create(Context,
304                                          T->getAsStructureType(),
305                                          TypeName,
306                                          3);
307          break;
308        }
309        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
310          // 4 x 4 Matrix type
311          ET = RSExportMatrixType::Create(Context,
312                                          T->getAsStructureType(),
313                                          TypeName,
314                                          4);
315          break;
316        }
317        default: {
318          // Others are primitive types
319          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
320          break;
321        }
322      }
323      break;
324    }
325    case clang::Type::Builtin: {
326      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
327      break;
328    }
329    case clang::Type::Pointer: {
330      ET = RSExportPointerType::Create(Context,
331                                       UNSAFE_CAST_TYPE(clang::PointerType, T),
332                                       TypeName);
333      // FIXME: free the name (allocated in RSExportType::GetTypeName)
334      delete [] TypeName.data();
335      break;
336    }
337    case clang::Type::ExtVector: {
338      ET = RSExportVectorType::Create(Context,
339                                      UNSAFE_CAST_TYPE(clang::ExtVectorType, T),
340                                      TypeName);
341      break;
342    }
343    case clang::Type::ConstantArray: {
344      ET = RSExportConstantArrayType::Create(
345              Context,
346              UNSAFE_CAST_TYPE(clang::ConstantArrayType, T));
347      break;
348    }
349    default: {
350      // TODO(zonr): warn that type is not exportable.
351      fprintf(stderr,
352              "RSExportType::Create : type '%s' is not exportable\n",
353              T->getTypeClassName());
354      break;
355    }
356  }
357
358  return ET;
359}
360
361RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
362  llvm::StringRef TypeName;
363  if (NormalizeType(T, TypeName))
364    return Create(Context, T, TypeName);
365  else
366    return NULL;
367}
368
369RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
370                                           const clang::VarDecl *VD) {
371  return RSExportType::Create(Context, GetTypeOfDecl(VD));
372}
373
374size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
375  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
376      ET->getLLVMType());
377}
378
379size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
380  if (ET->getClass() == RSExportType::ExportClassRecord)
381    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
382  else
383    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
384        ET->getLLVMType());
385}
386
387RSExportType::RSExportType(RSContext *Context,
388                           ExportClass Class,
389                           const llvm::StringRef &Name)
390    : RSExportable(Context, RSExportable::EX_TYPE),
391      mClass(Class),
392      // Make a copy on Name since memory stored @Name is either allocated in
393      // ASTContext or allocated in GetTypeName which will be destroyed later.
394      mName(Name.data(), Name.size()),
395      mLLVMType(NULL),
396      mSpecType(NULL) {
397  // Don't cache the type whose name start with '<'. Those type failed to
398  // get their name since constructing their name in GetTypeName() requiring
399  // complicated work.
400  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
401    // TODO(zonr): Need to check whether the insertion is successful or not.
402    Context->insertExportType(llvm::StringRef(Name), this);
403  return;
404}
405
406bool RSExportType::keep() {
407  if (!RSExportable::keep())
408    return false;
409  // Invalidate converted LLVM type.
410  mLLVMType = NULL;
411  return true;
412}
413
414bool RSExportType::equals(const RSExportable *E) const {
415  CHECK_PARENT_EQUALITY(RSExportable, E);
416  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
417}
418
// Frees mSpecType, which this object owns (it starts out NULL in the
// constructor; deleting NULL is a no-op).
RSExportType::~RSExportType() {
  delete mSpecType;
}
422
423/************************** RSExportPrimitiveType **************************/
// Map from RS-specific type names (matrix and object types) to their
// DataType; populated on first use in GetRSSpecificType(StringRef).
llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
RSExportPrimitiveType::RSSpecificTypeMap;

// LLVM type shared by all RS object types; built on first use in
// RSExportPrimitiveType::convertToLLVMType().
llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
428
429bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
430  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
431    return true;
432  else
433    return false;
434}
435
436RSExportPrimitiveType::DataType
437RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
438  if (TypeName.empty())
439    return DataTypeUnknown;
440
441  if (RSSpecificTypeMap->empty()) {
442#define ENUM_RS_MATRIX_TYPE(type, cname, dim)                       \
443    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
444#include "RSMatrixTypeEnums.inc"
445#define ENUM_RS_OBJECT_TYPE(type, cname)                            \
446    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
447#include "RSObjectTypeEnums.inc"
448  }
449
450  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
451  if (I == RSSpecificTypeMap->end())
452    return DataTypeUnknown;
453  else
454    return I->getValue();
455}
456
457RSExportPrimitiveType::DataType
458RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
459  T = GET_CANONICAL_TYPE(T);
460  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
461    return DataTypeUnknown;
462
463  return GetRSSpecificType( RSExportType::GetTypeName(T) );
464}
465
466bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
467  return ((DT >= FirstRSMatrixType) && (DT <= LastRSMatrixType));
468}
469
470bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
471  return ((DT >= FirstRSObjectType) && (DT <= LastRSObjectType));
472}
473
// Size in bits of each data type, generated from RSDataTypeEnums.inc and
// indexed by DataType value (see GetSizeInBits()). The trailing 0 is the
// entry for DataTypeMax.
const size_t RSExportPrimitiveType::SizeOfDataTypeInBits[] = {
#define ENUM_RS_DATA_TYPE(type, cname, bits)  \
  bits,
#include "RSDataTypeEnums.inc"
  0   // DataTypeMax
};
480
481size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
482  assert(((EPT->getType() > DataTypeUnknown) &&
483          (EPT->getType() < DataTypeMax)) &&
484         "RSExportPrimitiveType::GetSizeInBits : unknown data type");
485  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
486}
487
488RSExportPrimitiveType::DataType
489RSExportPrimitiveType::GetDataType(const clang::Type *T) {
490  if (T == NULL)
491    return DataTypeUnknown;
492
493  switch (T->getTypeClass()) {
494    case clang::Type::Builtin: {
495      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
496      switch (BT->getKind()) {
497#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
498        case builtin_type: {                                  \
499          return DataType ## type;                            \
500        }
501#include "RSClangBuiltinEnums.inc"
502        // The size of type WChar depend on platform so we abandon the support
503        // to them.
504        default: {
505          fprintf(stderr, "RSExportPrimitiveType::GetDataType : unsupported "
506                          "built-in type '%s'\n.", T->getTypeClassName());
507          break;
508        }
509      }
510      break;
511    }
512    case clang::Type::Record: {
513      // must be RS object type
514      return RSExportPrimitiveType::GetRSSpecificType(T);
515    }
516    default: {
517      fprintf(stderr, "RSExportPrimitiveType::GetDataType : type '%s' is not "
518                      "supported primitive type\n", T->getTypeClassName());
519      break;
520    }
521  }
522
523  return DataTypeUnknown;
524}
525
526RSExportPrimitiveType
527*RSExportPrimitiveType::Create(RSContext *Context,
528                               const clang::Type *T,
529                               const llvm::StringRef &TypeName,
530                               DataKind DK,
531                               bool Normalized) {
532  DataType DT = GetDataType(T);
533
534  if ((DT == DataTypeUnknown) || TypeName.empty())
535    return NULL;
536  else
537    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
538                                     DT, DK, Normalized);
539}
540
541RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
542                                                     const clang::Type *T,
543                                                     DataKind DK) {
544  llvm::StringRef TypeName;
545  if (RSExportType::NormalizeType(T, TypeName) && IsPrimitiveType(T))
546    return Create(Context, T, TypeName, DK);
547  else
548    return NULL;
549}
550
551const llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
552  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
553
554  if (isRSObjectType()) {
555    // struct {
556    //   int *p;
557    // } __attribute__((packed, aligned(pointer_size)))
558    //
559    // which is
560    //
561    // <{ [1 x i32] }> in LLVM
562    //
563    if (RSObjectLLVMType == NULL) {
564      std::vector<const llvm::Type *> Elements;
565      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
566      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
567    }
568    return RSObjectLLVMType;
569  }
570
571  switch (mType) {
572    case DataTypeFloat32: {
573      return llvm::Type::getFloatTy(C);
574      break;
575    }
576    case DataTypeFloat64: {
577      return llvm::Type::getDoubleTy(C);
578      break;
579    }
580    case DataTypeBoolean: {
581      return llvm::Type::getInt1Ty(C);
582      break;
583    }
584    case DataTypeSigned8:
585    case DataTypeUnsigned8: {
586      return llvm::Type::getInt8Ty(C);
587      break;
588    }
589    case DataTypeSigned16:
590    case DataTypeUnsigned16:
591    case DataTypeUnsigned565:
592    case DataTypeUnsigned5551:
593    case DataTypeUnsigned4444: {
594      return llvm::Type::getInt16Ty(C);
595      break;
596    }
597    case DataTypeSigned32:
598    case DataTypeUnsigned32: {
599      return llvm::Type::getInt32Ty(C);
600      break;
601    }
602    case DataTypeSigned64:
603    case DataTypeUnsigned64: {
604      return llvm::Type::getInt64Ty(C);
605      break;
606    }
607    default: {
608      assert(false && "Unknown data type");
609    }
610  }
611
612  return NULL;
613}
614
615union RSType *RSExportPrimitiveType::convertToSpecType() const {
616  llvm::OwningPtr<union RSType> ST(new union RSType);
617  RS_TYPE_SET_CLASS(ST, RS_TC_Primitive);
618  // enum RSExportPrimitiveType::DataType is synced with enum RSDataType in
619  // slang_rs_type_spec.h
620  RS_PRIMITIVE_TYPE_SET_DATA_TYPE(ST, getType());
621  return ST.take();
622}
623
624bool RSExportPrimitiveType::equals(const RSExportable *E) const {
625  CHECK_PARENT_EQUALITY(RSExportType, E);
626  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
627}
628
629/**************************** RSExportPointerType ****************************/
630
// Built-in integer type used as the pointee when a pointer to a pointer is
// exported (see RSExportPointerType::Create). Must be set before such a
// pointer is exported.
const clang::Type *RSExportPointerType::IntegerType = NULL;
632
633RSExportPointerType
634*RSExportPointerType::Create(RSContext *Context,
635                             const clang::PointerType *PT,
636                             const llvm::StringRef &TypeName) {
637  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
638  const RSExportType *PointeeET;
639
640  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
641    PointeeET = RSExportType::Create(Context, PointeeType);
642  } else {
643    // Double or higher dimension of pointer, export as int*
644    assert(IntegerType != NULL && "Built-in integer type is not set");
645    PointeeET = RSExportPrimitiveType::Create(Context, IntegerType);
646  }
647
648  if (PointeeET == NULL) {
649    fprintf(stderr, "Failed to create type for pointee");
650    return NULL;
651  }
652
653  return new RSExportPointerType(Context, TypeName, PointeeET);
654}
655
656const llvm::Type *RSExportPointerType::convertToLLVMType() const {
657  const llvm::Type *PointeeType = mPointeeType->getLLVMType();
658  return llvm::PointerType::getUnqual(PointeeType);
659}
660
661union RSType *RSExportPointerType::convertToSpecType() const {
662  llvm::OwningPtr<union RSType> ST(new union RSType);
663
664  RS_TYPE_SET_CLASS(ST, RS_TC_Pointer);
665  RS_POINTER_TYPE_SET_POINTEE_TYPE(ST, getPointeeType()->getSpecType());
666
667  if (RS_POINTER_TYPE_GET_POINTEE_TYPE(ST) != NULL)
668    return ST.take();
669  else
670    return NULL;
671}
672
673bool RSExportPointerType::keep() {
674  if (!RSExportType::keep())
675    return false;
676  const_cast<RSExportType*>(mPointeeType)->keep();
677  return true;
678}
679
680bool RSExportPointerType::equals(const RSExportable *E) const {
681  CHECK_PARENT_EQUALITY(RSExportType, E);
682  return (static_cast<const RSExportPointerType*>(E)
683              ->getPointeeType()->equals(getPointeeType()));
684}
685
686/***************************** RSExportVectorType *****************************/
687llvm::StringRef
688RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
689  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
690
691  if ((ElementType->getTypeClass() != clang::Type::Builtin))
692    return llvm::StringRef();
693
694  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType,
695                                                  ElementType);
696  if ((EVT->getNumElements() < 1) ||
697      (EVT->getNumElements() > 4))
698    return llvm::StringRef();
699
700  switch (BT->getKind()) {
701    // Compiler is smart enough to optimize following *big if branches* since
702    // they all become "constant comparison" after macro expansion
703#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
704    case builtin_type: {                                      \
705      const char *Name[] = { cname"2", cname"3", cname"4" };  \
706      return Name[EVT->getNumElements() - 2];                 \
707      break;                                                  \
708    }
709#include "RSClangBuiltinEnums.inc"
710    default: {
711      return llvm::StringRef();
712    }
713  }
714}
715
716RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
717                                               const clang::ExtVectorType *EVT,
718                                               const llvm::StringRef &TypeName,
719                                               DataKind DK,
720                                               bool Normalized) {
721  assert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
722
723  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
724  RSExportPrimitiveType::DataType DT =
725      RSExportPrimitiveType::GetDataType(ElementType);
726
727  if (DT != RSExportPrimitiveType::DataTypeUnknown)
728    return new RSExportVectorType(Context,
729                                  TypeName,
730                                  DT,
731                                  DK,
732                                  Normalized,
733                                  EVT->getNumElements());
734  else
735    fprintf(stderr, "RSExportVectorType::Create : unsupported base element "
736                    "type\n");
737  return NULL;
738}
739
740const llvm::Type *RSExportVectorType::convertToLLVMType() const {
741  const llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
742  return llvm::VectorType::get(ElementType, getNumElement());
743}
744
745union RSType *RSExportVectorType::convertToSpecType() const {
746  llvm::OwningPtr<union RSType> ST(new union RSType);
747
748  RS_TYPE_SET_CLASS(ST, RS_TC_Vector);
749  RS_VECTOR_TYPE_SET_ELEMENT_TYPE(ST, getType());
750  RS_VECTOR_TYPE_SET_VECTOR_SIZE(ST, getNumElement());
751
752  return ST.take();
753}
754
755bool RSExportVectorType::equals(const RSExportable *E) const {
756  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
757  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
758              == getNumElement());
759}
760
761/***************************** RSExportMatrixType *****************************/
762RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
763                                               const clang::RecordType *RT,
764                                               const llvm::StringRef &TypeName,
765                                               unsigned Dim) {
766  assert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
767  assert((Dim > 1) && "Invalid dimension of matrix");
768
769  // Check whether the struct rs_matrix is in our expected form (but assume it's
770  // correct if we're not sure whether it's correct or not)
771  const clang::RecordDecl* RD = RT->getDecl();
772  RD = RD->getDefinition();
773  if (RD != NULL) {
774    // Find definition, perform further examination
775    if (RD->field_empty()) {
776      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
777                      "must have 1 field for saving values", TypeName.data());
778      return NULL;
779    }
780
781    clang::RecordDecl::field_iterator FIT = RD->field_begin();
782    const clang::FieldDecl *FD = *FIT;
783    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
784    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
785      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
786                      "first field should be an array with constant size",
787              TypeName.data());
788      return NULL;
789    }
790    const clang::ConstantArrayType *CAT =
791      static_cast<const clang::ConstantArrayType *>(FT);
792    const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
793    if ((ElementType == NULL) ||
794        (ElementType->getTypeClass() != clang::Type::Builtin) ||
795        (static_cast<const clang::BuiltinType *>(ElementType)->getKind()
796          != clang::BuiltinType::Float)) {
797      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
798                      "first field should be a float array", TypeName.data());
799      return NULL;
800    }
801
802    if (CAT->getSize() != Dim * Dim) {
803      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
804                      "first field should be an array with size %d",
805              TypeName.data(), Dim * Dim);
806      return NULL;
807    }
808
809    FIT++;
810    if (FIT != RD->field_end()) {
811      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
812                      "must have exactly 1 field", TypeName.data());
813      return NULL;
814    }
815  }
816
817  return new RSExportMatrixType(Context, TypeName, Dim);
818}
819
820const llvm::Type *RSExportMatrixType::convertToLLVMType() const {
821  // Construct LLVM type:
822  // struct {
823  //  float X[mDim * mDim];
824  // }
825
826  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
827  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
828                                            mDim * mDim);
829  return llvm::StructType::get(C, X, NULL);
830}
831
832union RSType *RSExportMatrixType::convertToSpecType() const {
833  llvm::OwningPtr<union RSType> ST(new union RSType);
834  RS_TYPE_SET_CLASS(ST, RS_TC_Matrix);
835  switch (getDim()) {
836    case 2: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix2x2); break;
837    case 3: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix3x3); break;
838    case 4: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix4x4); break;
839    default: assert(false && "Matrix type with unsupported dimension.");
840  }
841  return ST.take();
842}
843
844bool RSExportMatrixType::equals(const RSExportable *E) const {
845  CHECK_PARENT_EQUALITY(RSExportType, E);
846  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
847}
848
849/************************* RSExportConstantArrayType *************************/
850RSExportConstantArrayType
851*RSExportConstantArrayType::Create(RSContext *Context,
852                                   const clang::ConstantArrayType *CAT) {
853  assert(CAT != NULL && CAT->getTypeClass() == clang::Type::ConstantArray);
854
855  assert((CAT->getSize().getActiveBits() < 32) && "array too large");
856
857  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
858  assert((Size > 0) && "Constant array should have size greater than 0");
859
860  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
861  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
862
863  if (ElementET == NULL) {
864    fprintf(stderr, "RSExportConstantArrayType::Create : failed to create "
865                    "RSExportType for array element.\n");
866    return NULL;
867  }
868
869  return new RSExportConstantArrayType(Context,
870                                       ElementET,
871                                       Size);
872}
873
874const llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
875  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
876}
877
878union RSType *RSExportConstantArrayType::convertToSpecType() const {
879  llvm::OwningPtr<union RSType> ST(new union RSType);
880
881  RS_TYPE_SET_CLASS(ST, RS_TC_ConstantArray);
882  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_TYPE(
883      ST, getElementType()->getSpecType());
884  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_SIZE(ST, getSize());
885
886  if (RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(ST) != NULL)
887    return ST.take();
888  else
889    return NULL;
890}
891
892bool RSExportConstantArrayType::keep() {
893  if (!RSExportType::keep())
894    return false;
895  const_cast<RSExportType*>(mElementType)->keep();
896  return true;
897}
898
899bool RSExportConstantArrayType::equals(const RSExportable *E) const {
900  CHECK_PARENT_EQUALITY(RSExportType, E);
901  return ((static_cast<const RSExportConstantArrayType*>(E)
902              ->getSize() == getSize()) && (mElementType->equals(E)));
903}
904
905/**************************** RSExportRecordType ****************************/
906RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
907                                               const clang::RecordType *RT,
908                                               const llvm::StringRef &TypeName,
909                                               bool mIsArtificial) {
910  assert(RT != NULL && RT->getTypeClass() == clang::Type::Record);
911
912  const clang::RecordDecl *RD = RT->getDecl();
913  assert(RD->isStruct());
914
915  RD = RD->getDefinition();
916  if (RD == NULL) {
917    // TODO(zonr): warn that actual struct definition isn't declared in this
918    //             moudle.
919    fprintf(stderr, "RSExportRecordType::Create : this struct is not defined "
920                    "in this module.");
921    return NULL;
922  }
923
924  // Struct layout construct by clang. We rely on this for obtaining the
925  // alloc size of a struct and offset of every field in that struct.
926  const clang::ASTRecordLayout *RL =
927      &Context->getASTContext()->getASTRecordLayout(RD);
928  assert((RL != NULL) && "Failed to retrieve the struct layout from Clang.");
929
930  RSExportRecordType *ERT =
931      new RSExportRecordType(Context,
932                             TypeName,
933                             RD->hasAttr<clang::PackedAttr>(),
934                             mIsArtificial,
935                             (RL->getSize() >> 3));
936  unsigned int Index = 0;
937
938  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
939           FE = RD->field_end();
940       FI != FE;
941       FI++, Index++) {
942#define FAILED_CREATE_FIELD(err)    do {         \
943      if (*err)                                                          \
944        fprintf(stderr, \
945                "RSExportRecordType::Create : failed to create field (%s)\n", \
946                err);                                                   \
947      delete ERT;                                                       \
948      return NULL;                                                      \
949    } while (false)
950
951    // FIXME: All fields should be primitive type
952    assert((*FI)->getKind() == clang::Decl::Field);
953    clang::FieldDecl *FD = *FI;
954
955    // We don't support bit field
956    //
957    // TODO(zonr): allow bitfield with size 8, 16, 32
958    if (FD->isBitField())
959      FAILED_CREATE_FIELD("bit field is not supported");
960
961    // Type
962    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
963
964    if (ET != NULL)
965      ERT->mFields.push_back(
966          new Field(ET, FD->getName(), ERT,
967                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
968    else
969      FAILED_CREATE_FIELD(FD->getName().str().c_str());
970#undef FAILED_CREATE_FIELD
971  }
972
973  return ERT;
974}
975
976const llvm::Type *RSExportRecordType::convertToLLVMType() const {
977  // Create an opaque type since struct may reference itself recursively.
978  llvm::PATypeHolder ResultHolder =
979      llvm::OpaqueType::get(getRSContext()->getLLVMContext());
980  setAbstractLLVMType(ResultHolder.get());
981
982  std::vector<const llvm::Type*> FieldTypes;
983
984  for (const_field_iterator FI = fields_begin(), FE = fields_end();
985       FI != FE;
986       FI++) {
987    const Field *F = *FI;
988    const RSExportType *FET = F->getType();
989
990    FieldTypes.push_back(FET->getLLVMType());
991  }
992
993  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
994                                               FieldTypes,
995                                               mIsPacked);
996  if (ST != NULL)
997    static_cast<llvm::OpaqueType*>(ResultHolder.get())
998        ->refineAbstractTypeTo(ST);
999  else
1000    return NULL;
1001  return ResultHolder.get();
1002}
1003
1004union RSType *RSExportRecordType::convertToSpecType() const {
1005  unsigned NumFields = getFields().size();
1006  unsigned AllocSize = sizeof(union RSType) +
1007                       sizeof(struct RSRecordField) * NumFields;
1008  llvm::OwningPtr<union RSType> ST(
1009      reinterpret_cast<union RSType*>(operator new (AllocSize)));
1010
1011  ::memset(ST.get(), 0, AllocSize);
1012
1013  RS_TYPE_SET_CLASS(ST, RS_TC_Record);
1014  RS_RECORD_TYPE_SET_NAME(ST, getName().c_str());
1015  RS_RECORD_TYPE_SET_NUM_FIELDS(ST, NumFields);
1016
1017  setSpecTypeTemporarily(ST.get());
1018
1019  unsigned FieldIdx = 0;
1020  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1021       FI != FE;
1022       FI++, FieldIdx++) {
1023    const Field *F = *FI;
1024
1025    RS_RECORD_TYPE_SET_FIELD_NAME(ST, FieldIdx, F->getName().c_str());
1026    RS_RECORD_TYPE_SET_FIELD_TYPE(ST, FieldIdx, F->getType()->getSpecType());
1027
1028    enum RSDataKind DK = RS_DK_User;
1029    if ((F->getType()->getClass() == ExportClassPrimitive) ||
1030        (F->getType()->getClass() == ExportClassVector)) {
1031      const RSExportPrimitiveType *EPT =
1032        static_cast<const RSExportPrimitiveType*>(F->getType());
1033      // enum RSExportPrimitiveType::DataKind is synced with enum RSDataKind in
1034      // slang_rs_type_spec.h
1035      DK = static_cast<enum RSDataKind>(EPT->getKind());
1036    }
1037    RS_RECORD_TYPE_SET_FIELD_DATA_KIND(ST, FieldIdx, DK);
1038  }
1039
1040  // TODO: Check whether all fields were created normaly.
1041
1042  return ST.take();
1043}
1044
1045bool RSExportRecordType::keep() {
1046  if (!RSExportType::keep())
1047    return false;
1048  for (std::list<const Field*>::iterator I = mFields.begin(),
1049          E = mFields.end();
1050       I != E;
1051       I++) {
1052    const_cast<RSExportType*>((*I)->getType())->keep();
1053  }
1054  return true;
1055}
1056
1057bool RSExportRecordType::equals(const RSExportable *E) const {
1058  CHECK_PARENT_EQUALITY(RSExportType, E);
1059
1060  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1061
1062  if (ERT->getFields().size() != getFields().size())
1063    return false;
1064
1065  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1066
1067  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1068    if (!(*AI)->getType()->equals((*BI)->getType()))
1069      return false;
1070    AI++;
1071    BI++;
1072  }
1073
1074  return true;
1075}
1076