slang_rs_export_type.cpp revision 7363d8430db732c42d392fcab47cf0e3f8eb4515
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <vector>
20
21#include "llvm/Type.h"
22#include "llvm/DerivedTypes.h"
23
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/Target/TargetData.h"
26
27#include "clang/AST/RecordLayout.h"
28
29#include "slang_rs_context.h"
30#include "slang_rs_export_element.h"
31
// Makes the enclosing equals() implementation return false as soon as the
// ParentClass part of the comparison fails.
//
// Wrapped in do { } while (false) so that the expansion is a single statement
// that requires the trailing ';' at the use site; the previous bare-`if` form
// expanded to `if (...) return false;;`, which breaks an enclosing unbraced
// if/else.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  do {                                        \
    if (!ParentClass::equals(E))              \
      return false;                           \
  } while (false)
35
36using namespace slang;
37
38/****************************** RSExportType ******************************/
39bool RSExportType::NormalizeType(const clang::Type *&T,
40                                 llvm::StringRef &TypeName) {
41  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
42      llvm::SmallPtrSet<const clang::Type*, 8>();
43
44  if ((T = RSExportType::TypeExportable(T, SPS)) == NULL)
45    // TODO(zonr): warn that type not exportable.
46    return false;
47
48  // Get type name
49  TypeName = RSExportType::GetTypeName(T);
50  if (TypeName.empty())
51    // TODO(zonr): warning that the type is unnamed.
52    return false;
53
54  return true;
55}
56
57const clang::Type
58*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
59  if (DD) {
60    clang::QualType T;
61    if (DD->getTypeSourceInfo())
62      T = DD->getTypeSourceInfo()->getType();
63    else
64      T = DD->getType();
65
66    if (T.isNull())
67      return NULL;
68    else
69      return T.getTypePtr();
70  }
71  return NULL;
72}
73
74llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
75  T = GET_CANONICAL_TYPE(T);
76  if (T == NULL)
77    return llvm::StringRef();
78
79  switch (T->getTypeClass()) {
80    case clang::Type::Builtin: {
81      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
82
83      switch (BT->getKind()) {
84#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
85        case builtin_type:                                    \
86          return cname;                                       \
87        break;
88#include "RSClangBuiltinEnums.inc"
89#undef ENUM_SUPPORT_BUILTIN_TYPE
90        default: {
91          assert(false && "Unknown data type of the builtin");
92          break;
93        }
94      }
95      break;
96    }
97    case clang::Type::Record: {
98      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
99      llvm::StringRef Name = RD->getName();
100      if (Name.empty()) {
101          if (RD->getTypedefForAnonDecl() != NULL)
102            Name = RD->getTypedefForAnonDecl()->getName();
103
104          if (Name.empty())
105            // Try to find a name from redeclaration (i.e. typedef)
106            for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
107                     RE = RD->redecls_end();
108                 RI != RE;
109                 RI++) {
110              assert(*RI != NULL && "cannot be NULL object");
111
112              Name = (*RI)->getName();
113              if (!Name.empty())
114                break;
115            }
116      }
117      return Name;
118    }
119    case clang::Type::Pointer: {
120      // "*" plus pointee name
121      const clang::Type *PT = GET_POINTEE_TYPE(T);
122      llvm::StringRef PointeeName;
123      if (NormalizeType(PT, PointeeName)) {
124        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
125        Name[0] = '*';
126        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
127        Name[PointeeName.size() + 1] = '\0';
128        return Name;
129      }
130      break;
131    }
132    case clang::Type::ExtVector: {
133      const clang::ExtVectorType *EVT =
134          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
135      return RSExportVectorType::GetTypeName(EVT);
136      break;
137    }
138    case clang::Type::ConstantArray : {
139      // Construct name for a constant array is too complicated.
140      return DUMMY_TYPE_NAME_FOR_RS_CONSTANT_ARRAY_TYPE;
141    }
142    default: {
143      break;
144    }
145  }
146
147  return llvm::StringRef();
148}
149
150const clang::Type *RSExportType::TypeExportable(
151    const clang::Type *T,
152    llvm::SmallPtrSet<const clang::Type*, 8>& SPS) {
153  // Normalize first
154  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
155    return NULL;
156
157  if (SPS.count(T))
158    return T;
159
160  switch (T->getTypeClass()) {
161    case clang::Type::Builtin: {
162      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
163
164      switch (BT->getKind()) {
165#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
166        case builtin_type:
167#include "RSClangBuiltinEnums.inc"
168#undef ENUM_SUPPORT_BUILTIN_TYPE
169          return T;
170        default: {
171          return NULL;
172        }
173      }
174    }
175    case clang::Type::Record: {
176      if (RSExportPrimitiveType::GetRSObjectType(T) !=
177          RSExportPrimitiveType::DataTypeUnknown)
178        return T;  // RS object type, no further checks are needed
179
180      // Check internal struct
181      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
182      if (RD != NULL)
183        RD = RD->getDefinition();
184
185      // Fast check
186      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
187        return NULL;
188
189      // Insert myself into checking set
190      SPS.insert(T);
191
192      // Check all element
193      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
194               FE = RD->field_end();
195           FI != FE;
196           FI++) {
197        const clang::FieldDecl *FD = *FI;
198        const clang::Type *FT = GetTypeOfDecl(FD);
199        FT = GET_CANONICAL_TYPE(FT);
200
201        if (!TypeExportable(FT, SPS)) {
202          fprintf(stderr, "Field `%s' in Record `%s' contains unsupported "
203                          "type\n", FD->getNameAsString().c_str(),
204                                    RD->getNameAsString().c_str());
205          FT->dump();
206          return NULL;
207        }
208      }
209
210      return T;
211    }
212    case clang::Type::Pointer: {
213      const clang::PointerType *PT = UNSAFE_CAST_TYPE(clang::PointerType, T);
214      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
215
216      if (PointeeType->getTypeClass() == clang::Type::Pointer)
217        return T;
218      // We don't support pointer with array-type pointee or unsupported pointee
219      // type
220      if (PointeeType->isArrayType() ||
221         (TypeExportable(PointeeType, SPS) == NULL) )
222        return NULL;
223      else
224        return T;
225    }
226    case clang::Type::ExtVector: {
227      const clang::ExtVectorType *EVT =
228          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
229      // Only vector with size 2, 3 and 4 are supported.
230      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
231        return NULL;
232
233      // Check base element type
234      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
235
236      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
237          (TypeExportable(ElementType, SPS) == NULL))
238        return NULL;
239      else
240        return T;
241    }
242    case clang::Type::ConstantArray: {
243      const clang::ConstantArrayType *CAT =
244          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);
245
246      // Check size
247      if (CAT->getSize().getActiveBits() > 32) {
248        fprintf(stderr, "RSExportConstantArrayType::Create : array with too "
249                        "large size (> 2^32).\n");
250        return NULL;
251      }
252      // Check element type
253      const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
254      if (ElementType->isArrayType()) {
255        fprintf(stderr, "RSExportType::TypeExportable : constant array with 2 "
256                        "or higher dimension of constant is not supported.\n");
257        return NULL;
258      }
259      if (TypeExportable(ElementType, SPS) == NULL)
260        return NULL;
261      else
262        return T;
263    }
264    default: {
265      return NULL;
266    }
267  }
268}
269
270RSExportType *RSExportType::Create(RSContext *Context,
271                                   const clang::Type *T,
272                                   const llvm::StringRef &TypeName) {
273  // Lookup the context to see whether the type was processed before.
274  // Newly created RSExportType will insert into context
275  // in RSExportType::RSExportType()
276  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
277
278  if (ETI != Context->export_types_end())
279    return ETI->second;
280
281  RSExportType *ET = NULL;
282  switch (T->getTypeClass()) {
283    case clang::Type::Record: {
284      RSExportPrimitiveType::DataType dt =
285          RSExportPrimitiveType::GetRSObjectType(TypeName);
286      switch (dt) {
287        case RSExportPrimitiveType::DataTypeUnknown: {
288          // User-defined types
289          ET = RSExportRecordType::Create(Context,
290                                          T->getAsStructureType(),
291                                          TypeName);
292          break;
293        }
294        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
295          // 2 x 2 Matrix type
296          ET = RSExportMatrixType::Create(Context,
297                                          T->getAsStructureType(),
298                                          TypeName,
299                                          2);
300          break;
301        }
302        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
303          // 3 x 3 Matrix type
304          ET = RSExportMatrixType::Create(Context,
305                                          T->getAsStructureType(),
306                                          TypeName,
307                                          3);
308          break;
309        }
310        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
311          // 4 x 4 Matrix type
312          ET = RSExportMatrixType::Create(Context,
313                                          T->getAsStructureType(),
314                                          TypeName,
315                                          4);
316          break;
317        }
318        default: {
319          // Others are primitive types
320          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
321          break;
322        }
323      }
324      break;
325    }
326    case clang::Type::Builtin: {
327      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
328      break;
329    }
330    case clang::Type::Pointer: {
331      ET = RSExportPointerType::Create(Context,
332                                       UNSAFE_CAST_TYPE(clang::PointerType, T),
333                                       TypeName);
334      // FIXME: free the name (allocated in RSExportType::GetTypeName)
335      delete [] TypeName.data();
336      break;
337    }
338    case clang::Type::ExtVector: {
339      ET = RSExportVectorType::Create(Context,
340                                      UNSAFE_CAST_TYPE(clang::ExtVectorType, T),
341                                      TypeName);
342      break;
343    }
344    case clang::Type::ConstantArray: {
345      ET = RSExportConstantArrayType::Create(
346              Context,
347              UNSAFE_CAST_TYPE(clang::ConstantArrayType, T));
348      break;
349    }
350    default: {
351      // TODO(zonr): warn that type is not exportable.
352      fprintf(stderr,
353              "RSExportType::Create : type '%s' is not exportable\n",
354              T->getTypeClassName());
355      break;
356    }
357  }
358
359  return ET;
360}
361
362RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
363  llvm::StringRef TypeName;
364  if (NormalizeType(T, TypeName))
365    return Create(Context, T, TypeName);
366  else
367    return NULL;
368}
369
370RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
371                                           const clang::VarDecl *VD) {
372  return RSExportType::Create(Context, GetTypeOfDecl(VD));
373}
374
375size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
376  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
377      ET->getLLVMType());
378}
379
380size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
381  if (ET->getClass() == RSExportType::ExportClassRecord)
382    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
383  else
384    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
385        ET->getLLVMType());
386}
387
388RSExportType::RSExportType(RSContext *Context,
389                           ExportClass Class,
390                           const llvm::StringRef &Name)
391    : RSExportable(Context, RSExportable::EX_TYPE),
392      mClass(Class),
393      // Make a copy on Name since memory stored @Name is either allocated in
394      // ASTContext or allocated in GetTypeName which will be destroyed later.
395      mName(Name.data(), Name.size()),
396      mLLVMType(NULL) {
397  // Don't cache the type whose name start with '<'. Those type failed to
398  // get their name since constructing their name in GetTypeName() requiring
399  // complicated work.
400  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
401    // TODO(zonr): Need to check whether the insertion is successful or not.
402    Context->insertExportType(llvm::StringRef(Name), this);
403  return;
404}
405
406void RSExportType::keep() {
407  // Invalidate converted LLVM type.
408  mLLVMType = NULL;
409  RSExportable::keep();
410  return;
411}
412
413bool RSExportType::equals(const RSExportable *E) const {
414  CHECK_PARENT_EQUALITY(RSExportable, E);
415  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
416}
417
418/************************** RSExportPrimitiveType **************************/
419llvm::ManagedStatic<RSExportPrimitiveType::RSObjectTypeMapTy>
420RSExportPrimitiveType::RSObjectTypeMap;
421
422llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
423
424bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
425  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
426    return true;
427  else
428    return false;
429}
430
431RSExportPrimitiveType::DataType
432RSExportPrimitiveType::GetRSObjectType(const llvm::StringRef &TypeName) {
433  if (TypeName.empty())
434    return DataTypeUnknown;
435
436  if (RSObjectTypeMap->empty()) {
437#define ENUM_RS_OBJECT_TYPE(type, cname)                            \
438    RSObjectTypeMap->GetOrCreateValue(cname, DataType ## type);
439#include "RSObjectTypeEnums.inc"
440#undef ENUM_RS_OBJECT_TYPE
441  }
442
443  RSObjectTypeMapTy::const_iterator I = RSObjectTypeMap->find(TypeName);
444  if (I == RSObjectTypeMap->end())
445    return DataTypeUnknown;
446  else
447    return I->getValue();
448}
449
450RSExportPrimitiveType::DataType
451RSExportPrimitiveType::GetRSObjectType(const clang::Type *T) {
452  T = GET_CANONICAL_TYPE(T);
453  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
454    return DataTypeUnknown;
455
456  return GetRSObjectType( RSExportType::GetTypeName(T) );
457}
458
459const size_t RSExportPrimitiveType::SizeOfDataTypeInBits[] = {
460#define ENUM_RS_DATA_TYPE(type, cname, bits)  \
461  bits,
462#include "RSDataTypeEnums.inc"
463#undef ENUM_RS_DATA_TYPE
464  0   // DataTypeMax
465};
466
467size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
468  assert(((EPT->getType() > DataTypeUnknown) &&
469          (EPT->getType() < DataTypeMax)) &&
470         "RSExportPrimitiveType::GetSizeInBits : unknown data type");
471  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
472}
473
474RSExportPrimitiveType::DataType
475RSExportPrimitiveType::GetDataType(const clang::Type *T) {
476  if (T == NULL)
477    return DataTypeUnknown;
478
479  switch (T->getTypeClass()) {
480    case clang::Type::Builtin: {
481      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
482      switch (BT->getKind()) {
483#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
484        case builtin_type: {                                  \
485          return DataType ## type;                            \
486        }
487#include "RSClangBuiltinEnums.inc"
488#undef ENUM_SUPPORT_BUILTIN_TYPE
489        // The size of type WChar depend on platform so we abandon the support
490        // to them.
491        default: {
492          fprintf(stderr, "RSExportPrimitiveType::GetDataType : unsupported "
493                          "built-in type '%s'\n.", T->getTypeClassName());
494          break;
495        }
496      }
497      break;
498    }
499    case clang::Type::Record: {
500      // must be RS object type
501      return RSExportPrimitiveType::GetRSObjectType(T);
502      break;
503    }
504    default: {
505      fprintf(stderr, "RSExportPrimitiveType::GetDataType : type '%s' is not "
506                      "supported primitive type\n", T->getTypeClassName());
507      break;
508    }
509  }
510
511  return DataTypeUnknown;
512}
513
514RSExportPrimitiveType
515*RSExportPrimitiveType::Create(RSContext *Context,
516                               const clang::Type *T,
517                               const llvm::StringRef &TypeName,
518                               DataKind DK,
519                               bool Normalized) {
520  DataType DT = GetDataType(T);
521
522  if ((DT == DataTypeUnknown) || TypeName.empty())
523    return NULL;
524  else
525    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
526                                     DT, DK, Normalized);
527}
528
529RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
530                                                     const clang::Type *T,
531                                                     DataKind DK) {
532  llvm::StringRef TypeName;
533  if (RSExportType::NormalizeType(T, TypeName) && IsPrimitiveType(T))
534    return Create(Context, T, TypeName, DK);
535  else
536    return NULL;
537}
538
539const llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
540  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
541
542  if (isRSObjectType()) {
543    // struct {
544    //   int *p;
545    // } __attribute__((packed, aligned(pointer_size)))
546    //
547    // which is
548    //
549    // <{ [1 x i32] }> in LLVM
550    //
551    if (RSObjectLLVMType == NULL) {
552      std::vector<const llvm::Type *> Elements;
553      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
554      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
555    }
556    return RSObjectLLVMType;
557  }
558
559  switch (mType) {
560    case DataTypeFloat32: {
561      return llvm::Type::getFloatTy(C);
562      break;
563    }
564    case DataTypeFloat64: {
565      return llvm::Type::getDoubleTy(C);
566      break;
567    }
568    case DataTypeBoolean: {
569      return llvm::Type::getInt1Ty(C);
570      break;
571    }
572    case DataTypeSigned8:
573    case DataTypeUnsigned8: {
574      return llvm::Type::getInt8Ty(C);
575      break;
576    }
577    case DataTypeSigned16:
578    case DataTypeUnsigned16:
579    case DataTypeUnsigned565:
580    case DataTypeUnsigned5551:
581    case DataTypeUnsigned4444: {
582      return llvm::Type::getInt16Ty(C);
583      break;
584    }
585    case DataTypeSigned32:
586    case DataTypeUnsigned32: {
587      return llvm::Type::getInt32Ty(C);
588      break;
589    }
590    case DataTypeSigned64:
591    case DataTypeUnsigned64: {
592      return llvm::Type::getInt64Ty(C);
593      break;
594    }
595    default: {
596      assert(false && "Unknown data type");
597    }
598  }
599
600  return NULL;
601}
602
603bool RSExportPrimitiveType::equals(const RSExportable *E) const {
604  CHECK_PARENT_EQUALITY(RSExportType, E);
605  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
606}
607
608/**************************** RSExportPointerType ****************************/
609
610const clang::Type *RSExportPointerType::IntegerType = NULL;
611
612RSExportPointerType
613*RSExportPointerType::Create(RSContext *Context,
614                             const clang::PointerType *PT,
615                             const llvm::StringRef &TypeName) {
616  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
617  const RSExportType *PointeeET;
618
619  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
620    PointeeET = RSExportType::Create(Context, PointeeType);
621  } else {
622    // Double or higher dimension of pointer, export as int*
623    assert(IntegerType != NULL && "Built-in integer type is not set");
624    PointeeET = RSExportPrimitiveType::Create(Context, IntegerType);
625  }
626
627  if (PointeeET == NULL) {
628    fprintf(stderr, "Failed to create type for pointee");
629    return NULL;
630  }
631
632  return new RSExportPointerType(Context, TypeName, PointeeET);
633}
634
635const llvm::Type *RSExportPointerType::convertToLLVMType() const {
636  const llvm::Type *PointeeType = mPointeeType->getLLVMType();
637  return llvm::PointerType::getUnqual(PointeeType);
638}
639
640void RSExportPointerType::keep() {
641  const_cast<RSExportType*>(mPointeeType)->keep();
642  RSExportType::keep();
643}
644
645bool RSExportPointerType::equals(const RSExportable *E) const {
646  CHECK_PARENT_EQUALITY(RSExportType, E);
647  return (static_cast<const RSExportPointerType*>(E)
648              ->getPointeeType()->equals(getPointeeType()));
649}
650
651/***************************** RSExportVectorType *****************************/
652llvm::StringRef
653RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
654  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
655
656  if ((ElementType->getTypeClass() != clang::Type::Builtin))
657    return llvm::StringRef();
658
659  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType,
660                                                  ElementType);
661  if ((EVT->getNumElements() < 1) ||
662      (EVT->getNumElements() > 4))
663    return llvm::StringRef();
664
665  switch (BT->getKind()) {
666    // Compiler is smart enough to optimize following *big if branches* since
667    // they all become "constant comparison" after macro expansion
668#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
669    case builtin_type: {                                      \
670      const char *Name[] = { cname"2", cname"3", cname"4" };  \
671      return Name[EVT->getNumElements() - 2];                 \
672      break;                                                  \
673    }
674#include "RSClangBuiltinEnums.inc"
675#undef ENUM_SUPPORT_BUILTIN_TYPE
676    default: {
677      return llvm::StringRef();
678    }
679  }
680}
681
682RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
683                                               const clang::ExtVectorType *EVT,
684                                               const llvm::StringRef &TypeName,
685                                               DataKind DK,
686                                               bool Normalized) {
687  assert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
688
689  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
690  RSExportPrimitiveType::DataType DT =
691      RSExportPrimitiveType::GetDataType(ElementType);
692
693  if (DT != RSExportPrimitiveType::DataTypeUnknown)
694    return new RSExportVectorType(Context,
695                                  TypeName,
696                                  DT,
697                                  DK,
698                                  Normalized,
699                                  EVT->getNumElements());
700  else
701    fprintf(stderr, "RSExportVectorType::Create : unsupported base element "
702                    "type\n");
703  return NULL;
704}
705
706const llvm::Type *RSExportVectorType::convertToLLVMType() const {
707  const llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
708  return llvm::VectorType::get(ElementType, getNumElement());
709}
710
711bool RSExportVectorType::equals(const RSExportable *E) const {
712  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
713  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
714              == getNumElement());
715}
716
717/***************************** RSExportMatrixType *****************************/
718RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
719                                               const clang::RecordType *RT,
720                                               const llvm::StringRef &TypeName,
721                                               unsigned Dim) {
722  assert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
723  assert((Dim > 1) && "Invalid dimension of matrix");
724
725  // Check whether the struct rs_matrix is in our expected form (but assume it's
726  // correct if we're not sure whether it's correct or not)
727  const clang::RecordDecl* RD = RT->getDecl();
728  RD = RD->getDefinition();
729  if (RD != NULL) {
730    // Find definition, perform further examination
731    if (RD->field_empty()) {
732      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
733                      "must have 1 field for saving values", TypeName.data());
734      return NULL;
735    }
736
737    clang::RecordDecl::field_iterator FIT = RD->field_begin();
738    const clang::FieldDecl *FD = *FIT;
739    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
740    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
741      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
742                      "first field should be an array with constant size",
743              TypeName.data());
744      return NULL;
745    }
746    const clang::ConstantArrayType *CAT =
747      static_cast<const clang::ConstantArrayType *>(FT);
748    const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
749    if ((ElementType == NULL) ||
750        (ElementType->getTypeClass() != clang::Type::Builtin) ||
751        (static_cast<const clang::BuiltinType *>(ElementType)->getKind()
752          != clang::BuiltinType::Float)) {
753      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
754                      "first field should be a float array", TypeName.data());
755      return NULL;
756    }
757
758    if (CAT->getSize() != Dim * Dim) {
759      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
760                      "first field should be an array with size %d",
761              TypeName.data(), Dim * Dim);
762      return NULL;
763    }
764
765    FIT++;
766    if (FIT != RD->field_end()) {
767      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
768                      "must have exactly 1 field", TypeName.data());
769      return NULL;
770    }
771  }
772
773  return new RSExportMatrixType(Context, TypeName, Dim);
774}
775
776const llvm::Type *RSExportMatrixType::convertToLLVMType() const {
777  // Construct LLVM type:
778  // struct {
779  //  float X[mDim * mDim];
780  // }
781
782  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
783  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
784                                            mDim * mDim);
785  return llvm::StructType::get(C, X, NULL);
786}
787
788bool RSExportMatrixType::equals(const RSExportable *E) const {
789  CHECK_PARENT_EQUALITY(RSExportType, E);
790  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
791}
792
793/************************* RSExportConstantArrayType *************************/
794RSExportConstantArrayType
795*RSExportConstantArrayType::Create(RSContext *Context,
796                                   const clang::ConstantArrayType *CAT) {
797  assert(CAT != NULL && CAT->getTypeClass() == clang::Type::ConstantArray);
798
799  assert((CAT->getSize().getActiveBits() < 32) && "array too large");
800
801  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
802  assert((Size > 0) && "Constant array should have size greater than 0");
803
804  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
805  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
806
807  if (ElementET == NULL) {
808    fprintf(stderr, "RSExportConstantArrayType::Create : failed to create "
809                    "RSExportType for array element.\n");
810    return NULL;
811  }
812
813  return new RSExportConstantArrayType(Context,
814                                       ElementET,
815                                       Size);
816}
817
818const llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
819  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
820}
821
822void RSExportConstantArrayType::keep() {
823  const_cast<RSExportType*>(mElementType)->keep();
824  RSExportType::keep();
825  return;
826}
827
828bool RSExportConstantArrayType::equals(const RSExportable *E) const {
829  CHECK_PARENT_EQUALITY(RSExportType, E);
830  return ((static_cast<const RSExportConstantArrayType*>(E)
831              ->getSize() == getSize()) && (mElementType->equals(E)));
832}
833
834/**************************** RSExportRecordType ****************************/
// Builds the RSExportRecordType for struct RT, named TypeName.
//
// RT must be a struct whose definition is visible in this module; otherwise a
// diagnostic is printed and NULL is returned. The total size and every field
// offset come from Clang's ASTRecordLayout for the struct. Each field is
// converted with RSExportElement::CreateFromDecl(). A bit field, or a field
// whose type cannot be exported, makes the whole record fail: the partially
// built record is deleted and NULL is returned.
//
// NOTE(review): ERT is constructed (and, through the RSExportType constructor,
// presumably registered in Context under TypeName) before its fields are
// converted. That ordering presumably lets self-referential fields such as
// `struct S *next` find it. On a later field failure ERT is deleted, but
// nothing here removes it from Context. Confirm whether the context can be
// left holding a dangling pointer.
RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
                                               const clang::RecordType *RT,
                                               const llvm::StringRef &TypeName,
                                               bool mIsArtificial) {
  assert(RT != NULL && RT->getTypeClass() == clang::Type::Record);

  const clang::RecordDecl *RD = RT->getDecl();
  assert(RD->isStruct());

  RD = RD->getDefinition();
  if (RD == NULL) {
    // TODO(zonr): warn that actual struct definition isn't declared in this
    //             module.
    fprintf(stderr, "RSExportRecordType::Create : this struct is not defined "
                    "in this module.");
    return NULL;
  }

  // Struct layout construct by clang. We rely on this for obtaining the
  // alloc size of a struct and offset of every field in that struct.
  const clang::ASTRecordLayout *RL =
      &Context->getASTContext()->getASTRecordLayout(RD);
  assert((RL != NULL) && "Failed to retrieve the struct layout from Clang.");

  // The `>> 3` presumably converts Clang's bit-based record size into bytes.
  RSExportRecordType *ERT =
      new RSExportRecordType(Context,
                             TypeName,
                             RD->hasAttr<clang::PackedAttr>(),
                             mIsArtificial,
                             (RL->getSize() >> 3));
  // Index tracks the field position for RL->getFieldOffset().
  unsigned int Index = 0;

  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
           FE = RD->field_end();
       FI != FE;
       FI++, Index++) {
    // Bails out of Create(): prints err only when it is a non-empty string,
    // then deletes the partially built record and returns NULL.
#define FAILED_CREATE_FIELD(err)    do {         \
      if (*err)                                                          \
        fprintf(stderr, \
                "RSExportRecordType::Create : failed to create field (%s)\n", \
                err);                                                   \
      delete ERT;                                                       \
      return NULL;                                                      \
    } while (false)

    // FIXME: All fields should be primitive type
    assert((*FI)->getKind() == clang::Decl::Field);
    clang::FieldDecl *FD = *FI;

    // We don't support bit field
    //
    // TODO(zonr): allow bitfield with size 8, 16, 32
    if (FD->isBitField())
      FAILED_CREATE_FIELD("bit field is not supported");

    // Type
    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);

    // The field offset is likewise shifted by 3 (presumably bits -> bytes).
    if (ET != NULL)
      ERT->mFields.push_back(
          new Field(ET, FD->getName(), ERT,
                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
    else
      FAILED_CREATE_FIELD(FD->getName().str().c_str());
#undef FAILED_CREATE_FIELD
  }

  return ERT;
}
904
905const llvm::Type *RSExportRecordType::convertToLLVMType() const {
906  std::vector<const llvm::Type*> FieldTypes;
907
908  for (const_field_iterator FI = fields_begin(),
909           FE = fields_end();
910       FI != FE;
911       FI++) {
912    const Field *F = *FI;
913    const RSExportType *FET = F->getType();
914
915    FieldTypes.push_back(FET->getLLVMType());
916  }
917
918  return llvm::StructType::get(getRSContext()->getLLVMContext(),
919                               FieldTypes,
920                               mIsPacked);
921}
922
923void RSExportRecordType::keep() {
924  for (std::list<const Field*>::iterator I = mFields.begin(),
925          E = mFields.end();
926       I != E;
927       I++) {
928    const_cast<RSExportType*>((*I)->getType())->keep();
929  }
930  RSExportType::keep();
931  return;
932}
933
934bool RSExportRecordType::equals(const RSExportable *E) const {
935  CHECK_PARENT_EQUALITY(RSExportType, E);
936
937  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
938
939  if (ERT->getFields().size() != getFields().size())
940    return false;
941
942  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
943
944  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
945    if (!(*AI)->getType()->equals((*BI)->getType()))
946      return false;
947    AI++;
948    BI++;
949  }
950
951  return true;
952}
953