slang_rs_export_type.cpp revision dd6206bb61bf8df2ed6b643abe8a29c48a315685
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <list>
20#include <vector>
21
22#include "clang/AST/RecordLayout.h"
23
24#include "llvm/ADT/StringExtras.h"
25
26#include "llvm/DerivedTypes.h"
27
28#include "llvm/Target/TargetData.h"
29
30#include "llvm/Type.h"
31
32#include "slang_rs_context.h"
33#include "slang_rs_export_element.h"
34#include "slang_rs_type_spec.h"
35
// Makes the enclosing equals() return false as soon as the ParentClass part
// of the comparison against E fails. The body is wrapped in do { } while (0)
// so the macro expands to exactly one statement: it stays correct inside an
// unbraced if/else and requires the trailing semicolon at the use site.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  do {                                        \
    if (!ParentClass::equals(E))              \
      return false;                           \
  } while (0)
39
40namespace slang {
41
42namespace {
43
44const clang::Type *TypeExportableHelper(
45    const clang::Type *T,
46    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
47    clang::Diagnostic *Diags,
48    clang::SourceManager *SM,
49    const clang::VarDecl *VD,
50    const clang::RecordDecl *TopLevelRecord) {
51  // Normalize first
52  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
53    return NULL;
54
55  if (SPS.count(T))
56    return T;
57
58  switch (T->getTypeClass()) {
59    case clang::Type::Builtin: {
60      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
61
62      switch (BT->getKind()) {
63#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
64        case builtin_type:
65#include "RSClangBuiltinEnums.inc"
66          return T;
67        default: {
68          return NULL;
69        }
70      }
71    }
72    case clang::Type::Record: {
73      if (RSExportPrimitiveType::GetRSSpecificType(T) !=
74          RSExportPrimitiveType::DataTypeUnknown)
75        return T;  // RS object type, no further checks are needed
76
77      // Check internal struct
78      clang::RecordDecl *RD = NULL;
79      if (T->isUnionType()) {
80        RD = T->getAsUnionType()->getDecl();
81        if (Diags && SM) {
82          Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
83                        Diags->getCustomDiagID(clang::Diagnostic::Error,
84                                               "unions cannot "
85                                               "be exported: '%0'"))
86              << RD->getName();
87        }
88        return NULL;
89      } else if (!T->isStructureType()) {
90        assert(false && "Unknown type cannot be exported");
91        return NULL;
92      }
93
94      RD = T->getAsStructureType()->getDecl();
95      if (RD != NULL)
96        RD = RD->getDefinition();
97
98      if (!TopLevelRecord) {
99        TopLevelRecord = RD;
100      }
101      if (RD->getName().empty()) {
102        if (Diags && SM) {
103          Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
104                        Diags->getCustomDiagID(clang::Diagnostic::Error,
105                                               "anonymous structures cannot "
106                                               "be exported"));
107        }
108        return NULL;
109      }
110
111      // Fast check
112      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
113        return NULL;
114
115      // Insert myself into checking set
116      SPS.insert(T);
117
118      // Check all element
119      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
120               FE = RD->field_end();
121           FI != FE;
122           FI++) {
123        const clang::FieldDecl *FD = *FI;
124        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
125        FT = GET_CANONICAL_TYPE(FT);
126
127        if (!TypeExportableHelper(FT, SPS, Diags, SM, VD, TopLevelRecord)) {
128          return NULL;
129        }
130      }
131
132      return T;
133    }
134    case clang::Type::Pointer: {
135      if (TopLevelRecord) {
136        if (Diags && SM) {
137          Diags->Report(clang::FullSourceLoc(TopLevelRecord->getLocation(),
138                            *SM),
139                        Diags->getCustomDiagID(clang::Diagnostic::Error,
140                                               "structures containing pointers "
141                                               "cannot be exported: '%0'"))
142              << TopLevelRecord->getName();
143        }
144        return NULL;
145      }
146      const clang::PointerType *PT = UNSAFE_CAST_TYPE(clang::PointerType, T);
147      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
148
149      if (PointeeType->getTypeClass() == clang::Type::Pointer)
150        return T;
151      // We don't support pointer with array-type pointee or unsupported pointee
152      // type
153      if (PointeeType->isArrayType() ||
154          (TypeExportableHelper(PointeeType, SPS, Diags, SM, VD,
155                                TopLevelRecord) == NULL))
156        return NULL;
157      else
158        return T;
159    }
160    case clang::Type::ExtVector: {
161      const clang::ExtVectorType *EVT =
162          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
163      // Only vector with size 2, 3 and 4 are supported.
164      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
165        return NULL;
166
167      // Check base element type
168      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
169
170      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
171          (TypeExportableHelper(ElementType, SPS, Diags, SM, VD,
172                                TopLevelRecord) == NULL))
173        return NULL;
174      else
175        return T;
176    }
177    case clang::Type::ConstantArray: {
178      const clang::ConstantArrayType *CAT =
179          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);
180
181      // Check size
182      if (CAT->getSize().getActiveBits() > 32) {
183        fprintf(stderr, "RSExportConstantArrayType::Create : array with too "
184                        "large size (> 2^32).\n");
185        return NULL;
186      }
187      // Check element type
188      const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
189      if (ElementType->isArrayType()) {
190        fprintf(stderr, "RSExportType::TypeExportableHelper : constant array "
191                        "with 2 or higher dimension of constant is not "
192                        "supported.\n");
193        return NULL;
194      } else if (ElementType->isExtVectorType()) {
195        const clang::ExtVectorType *EVT =
196            static_cast<const clang::ExtVectorType*>(ElementType);
197        unsigned numElements = EVT->getNumElements();
198
199        const clang::Type *BaseElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
200        if (!RSExportPrimitiveType::IsPrimitiveType(BaseElementType)) {
201          if (Diags && SM) {
202            Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
203                                                 "vectors of non-primitive "
204                                                 "types cannot be exported"));
205
206            if (TopLevelRecord) {
207              Diags->Report(clang::FullSourceLoc(TopLevelRecord->getLocation(),
208                                *SM),
209                            Diags->getCustomDiagID(clang::Diagnostic::Error,
210                                                 "vectors of non-primitive "
211                                                 "types cannot be exported: "
212                                                 "'%0'"))
213                   << TopLevelRecord->getName();
214            } else if (VD) {
215              Diags->Report(clang::FullSourceLoc(VD->getLocation(), *SM),
216                            Diags->getCustomDiagID(clang::Diagnostic::Error,
217                                                 "vectors of non-primitive "
218                                                 "types cannot be exported: "
219                                                 "'%0'"))
220                   << VD->getName();
221            } else {
222              assert(false && "Variables should be validated before exporting");
223            }
224          }
225          return NULL;
226        }
227
228        if (numElements == 3 && CAT->getSize() != 1) {
229          if (Diags && SM) {
230            if (TopLevelRecord) {
231              Diags->Report(clang::FullSourceLoc(TopLevelRecord->getLocation(),
232                                *SM),
233                            Diags->getCustomDiagID(clang::Diagnostic::Error,
234                                                   "arrays of width 3 vector "
235                                                   "types cannot be exported: "
236                                                   "'%0'"))
237                   << TopLevelRecord->getName();
238            } else if (VD) {
239              Diags->Report(clang::FullSourceLoc(VD->getLocation(), *SM),
240                            Diags->getCustomDiagID(clang::Diagnostic::Error,
241                                                   "arrays of width 3 vector "
242                                                   "types cannot be exported: "
243                                                   "'%0'"))
244                   << VD->getName();
245            } else {
246              assert(false && "Variables should be validated before exporting");
247            }
248          }
249          return NULL;
250        }
251      }
252
253      if (TypeExportableHelper(ElementType, SPS, Diags, SM, VD,
254                               TopLevelRecord) == NULL)
255        return NULL;
256      else
257        return T;
258    }
259    default: {
260      return NULL;
261    }
262  }
263}
264
265// Return the type that can be used to create RSExportType, will always return
266// the canonical type
267// If the Type T is not exportable, this function returns NULL. Diags and SM
268// are used to generate proper Clang diagnostic messages when a
269// non-exportable type is detected. TopLevelRecord is used to capture the
270// highest struct (in the case of a nested hierarchy) for detecting other
271// types that cannot be exported (mostly pointers within a struct).
272static const clang::Type *TypeExportable(const clang::Type *T,
273                                         clang::Diagnostic *Diags,
274                                         clang::SourceManager *SM,
275                                         const clang::VarDecl *VD) {
276  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
277      llvm::SmallPtrSet<const clang::Type*, 8>();
278
279  return TypeExportableHelper(T, SPS, Diags, SM, VD, NULL);
280}
281
282}  // namespace
283
284/****************************** RSExportType ******************************/
285bool RSExportType::NormalizeType(const clang::Type *&T,
286                                 llvm::StringRef &TypeName,
287                                 clang::Diagnostic *Diags,
288                                 clang::SourceManager *SM,
289                                 const clang::VarDecl *VD) {
290  if ((T = TypeExportable(T, Diags, SM, VD)) == NULL) {
291    return false;
292  }
293  // Get type name
294  TypeName = RSExportType::GetTypeName(T);
295  if (TypeName.empty()) {
296    if (Diags && SM) {
297      Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
298                                           "anonymous types cannot "
299                                           "be exported"));
300    }
301    return false;
302  }
303
304  return true;
305}
306
307const clang::Type
308*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
309  if (DD) {
310    clang::QualType T;
311    if (DD->getTypeSourceInfo())
312      T = DD->getTypeSourceInfo()->getType();
313    else
314      T = DD->getType();
315
316    if (T.isNull())
317      return NULL;
318    else
319      return T.getTypePtr();
320  }
321  return NULL;
322}
323
324llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
325  T = GET_CANONICAL_TYPE(T);
326  if (T == NULL)
327    return llvm::StringRef();
328
329  switch (T->getTypeClass()) {
330    case clang::Type::Builtin: {
331      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
332
333      switch (BT->getKind()) {
334#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
335        case builtin_type:                                    \
336          return cname;                                       \
337        break;
338#include "RSClangBuiltinEnums.inc"
339        default: {
340          assert(false && "Unknown data type of the builtin");
341          break;
342        }
343      }
344      break;
345    }
346    case clang::Type::Record: {
347      clang::RecordDecl *RD;
348      if (T->isStructureType()) {
349        RD = T->getAsStructureType()->getDecl();
350      } else {
351        break;
352      }
353
354      llvm::StringRef Name = RD->getName();
355      if (Name.empty()) {
356          if (RD->getTypedefForAnonDecl() != NULL)
357            Name = RD->getTypedefForAnonDecl()->getName();
358
359          if (Name.empty())
360            // Try to find a name from redeclaration (i.e. typedef)
361            for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
362                     RE = RD->redecls_end();
363                 RI != RE;
364                 RI++) {
365              assert(*RI != NULL && "cannot be NULL object");
366
367              Name = (*RI)->getName();
368              if (!Name.empty())
369                break;
370            }
371      }
372      return Name;
373    }
374    case clang::Type::Pointer: {
375      // "*" plus pointee name
376      const clang::Type *PT = GET_POINTEE_TYPE(T);
377      llvm::StringRef PointeeName;
378      if (NormalizeType(PT, PointeeName, NULL, NULL, NULL)) {
379        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
380        Name[0] = '*';
381        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
382        Name[PointeeName.size() + 1] = '\0';
383        return Name;
384      }
385      break;
386    }
387    case clang::Type::ExtVector: {
388      const clang::ExtVectorType *EVT =
389          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
390      return RSExportVectorType::GetTypeName(EVT);
391      break;
392    }
393    case clang::Type::ConstantArray : {
394      // Construct name for a constant array is too complicated.
395      return DUMMY_TYPE_NAME_FOR_RS_CONSTANT_ARRAY_TYPE;
396    }
397    default: {
398      break;
399    }
400  }
401
402  return llvm::StringRef();
403}
404
405
406RSExportType *RSExportType::Create(RSContext *Context,
407                                   const clang::Type *T,
408                                   const llvm::StringRef &TypeName) {
409  // Lookup the context to see whether the type was processed before.
410  // Newly created RSExportType will insert into context
411  // in RSExportType::RSExportType()
412  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
413
414  if (ETI != Context->export_types_end())
415    return ETI->second;
416
417  RSExportType *ET = NULL;
418  switch (T->getTypeClass()) {
419    case clang::Type::Record: {
420      RSExportPrimitiveType::DataType dt =
421          RSExportPrimitiveType::GetRSSpecificType(TypeName);
422      switch (dt) {
423        case RSExportPrimitiveType::DataTypeUnknown: {
424          // User-defined types
425          ET = RSExportRecordType::Create(Context,
426                                          T->getAsStructureType(),
427                                          TypeName);
428          break;
429        }
430        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
431          // 2 x 2 Matrix type
432          ET = RSExportMatrixType::Create(Context,
433                                          T->getAsStructureType(),
434                                          TypeName,
435                                          2);
436          break;
437        }
438        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
439          // 3 x 3 Matrix type
440          ET = RSExportMatrixType::Create(Context,
441                                          T->getAsStructureType(),
442                                          TypeName,
443                                          3);
444          break;
445        }
446        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
447          // 4 x 4 Matrix type
448          ET = RSExportMatrixType::Create(Context,
449                                          T->getAsStructureType(),
450                                          TypeName,
451                                          4);
452          break;
453        }
454        default: {
455          // Others are primitive types
456          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
457          break;
458        }
459      }
460      break;
461    }
462    case clang::Type::Builtin: {
463      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
464      break;
465    }
466    case clang::Type::Pointer: {
467      ET = RSExportPointerType::Create(Context,
468                                       UNSAFE_CAST_TYPE(clang::PointerType, T),
469                                       TypeName);
470      // FIXME: free the name (allocated in RSExportType::GetTypeName)
471      delete [] TypeName.data();
472      break;
473    }
474    case clang::Type::ExtVector: {
475      ET = RSExportVectorType::Create(Context,
476                                      UNSAFE_CAST_TYPE(clang::ExtVectorType, T),
477                                      TypeName);
478      break;
479    }
480    case clang::Type::ConstantArray: {
481      ET = RSExportConstantArrayType::Create(
482              Context,
483              UNSAFE_CAST_TYPE(clang::ConstantArrayType, T));
484      break;
485    }
486    default: {
487      // TODO(zonr): warn that type is not exportable.
488      fprintf(stderr,
489              "RSExportType::Create : type '%s' is not exportable\n",
490              T->getTypeClassName());
491      break;
492    }
493  }
494
495  return ET;
496}
497
498RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
499  llvm::StringRef TypeName;
500  if (NormalizeType(T, TypeName, NULL, NULL, NULL))
501    return Create(Context, T, TypeName);
502  else
503    return NULL;
504}
505
506RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
507                                           const clang::VarDecl *VD) {
508  return RSExportType::Create(Context, GetTypeOfDecl(VD));
509}
510
511size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
512  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
513      ET->getLLVMType());
514}
515
516size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
517  if (ET->getClass() == RSExportType::ExportClassRecord)
518    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
519  else
520    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
521        ET->getLLVMType());
522}
523
524RSExportType::RSExportType(RSContext *Context,
525                           ExportClass Class,
526                           const llvm::StringRef &Name)
527    : RSExportable(Context, RSExportable::EX_TYPE),
528      mClass(Class),
529      // Make a copy on Name since memory stored @Name is either allocated in
530      // ASTContext or allocated in GetTypeName which will be destroyed later.
531      mName(Name.data(), Name.size()),
532      mLLVMType(NULL),
533      mSpecType(NULL) {
534  // Don't cache the type whose name start with '<'. Those type failed to
535  // get their name since constructing their name in GetTypeName() requiring
536  // complicated work.
537  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
538    // TODO(zonr): Need to check whether the insertion is successful or not.
539    Context->insertExportType(llvm::StringRef(Name), this);
540  return;
541}
542
543bool RSExportType::keep() {
544  if (!RSExportable::keep())
545    return false;
546  // Invalidate converted LLVM type.
547  mLLVMType = NULL;
548  return true;
549}
550
551bool RSExportType::equals(const RSExportable *E) const {
552  CHECK_PARENT_EQUALITY(RSExportable, E);
553  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
554}
555
556RSExportType::~RSExportType() {
557  delete mSpecType;
558}
559
560/************************** RSExportPrimitiveType **************************/
561llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
562RSExportPrimitiveType::RSSpecificTypeMap;
563
564llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
565
566bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
567  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
568    return true;
569  else
570    return false;
571}
572
573RSExportPrimitiveType::DataType
574RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
575  if (TypeName.empty())
576    return DataTypeUnknown;
577
578  if (RSSpecificTypeMap->empty()) {
579#define ENUM_RS_MATRIX_TYPE(type, cname, dim)                       \
580    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
581#include "RSMatrixTypeEnums.inc"
582#define ENUM_RS_OBJECT_TYPE(type, cname)                            \
583    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
584#include "RSObjectTypeEnums.inc"
585  }
586
587  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
588  if (I == RSSpecificTypeMap->end())
589    return DataTypeUnknown;
590  else
591    return I->getValue();
592}
593
594RSExportPrimitiveType::DataType
595RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
596  T = GET_CANONICAL_TYPE(T);
597  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
598    return DataTypeUnknown;
599
600  return GetRSSpecificType( RSExportType::GetTypeName(T) );
601}
602
603bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
604  return ((DT >= FirstRSMatrixType) && (DT <= LastRSMatrixType));
605}
606
607bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
608  return ((DT >= FirstRSObjectType) && (DT <= LastRSObjectType));
609}
610
611const size_t RSExportPrimitiveType::SizeOfDataTypeInBits[] = {
612#define ENUM_RS_DATA_TYPE(type, cname, bits)  \
613  bits,
614#include "RSDataTypeEnums.inc"
615  0   // DataTypeMax
616};
617
618size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
619  assert(((EPT->getType() > DataTypeUnknown) &&
620          (EPT->getType() < DataTypeMax)) &&
621         "RSExportPrimitiveType::GetSizeInBits : unknown data type");
622  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
623}
624
625RSExportPrimitiveType::DataType
626RSExportPrimitiveType::GetDataType(const clang::Type *T) {
627  if (T == NULL)
628    return DataTypeUnknown;
629
630  switch (T->getTypeClass()) {
631    case clang::Type::Builtin: {
632      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
633      switch (BT->getKind()) {
634#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
635        case builtin_type: {                                  \
636          return DataType ## type;                            \
637        }
638#include "RSClangBuiltinEnums.inc"
639        // The size of type WChar depend on platform so we abandon the support
640        // to them.
641        default: {
642          fprintf(stderr, "RSExportPrimitiveType::GetDataType : unsupported "
643                          "built-in type '%s'\n.", T->getTypeClassName());
644          break;
645        }
646      }
647      break;
648    }
649    case clang::Type::Record: {
650      // must be RS object type
651      return RSExportPrimitiveType::GetRSSpecificType(T);
652    }
653    default: {
654      fprintf(stderr, "RSExportPrimitiveType::GetDataType : type '%s' is not "
655                      "supported primitive type\n", T->getTypeClassName());
656      break;
657    }
658  }
659
660  return DataTypeUnknown;
661}
662
663RSExportPrimitiveType
664*RSExportPrimitiveType::Create(RSContext *Context,
665                               const clang::Type *T,
666                               const llvm::StringRef &TypeName,
667                               DataKind DK,
668                               bool Normalized) {
669  DataType DT = GetDataType(T);
670
671  if ((DT == DataTypeUnknown) || TypeName.empty())
672    return NULL;
673  else
674    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
675                                     DT, DK, Normalized);
676}
677
678RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
679                                                     const clang::Type *T,
680                                                     DataKind DK) {
681  llvm::StringRef TypeName;
682  if (RSExportType::NormalizeType(T, TypeName, NULL, NULL, NULL) &&
683      IsPrimitiveType(T)) {
684    return Create(Context, T, TypeName, DK);
685  } else {
686    return NULL;
687  }
688}
689
690const llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
691  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
692
693  if (isRSObjectType()) {
694    // struct {
695    //   int *p;
696    // } __attribute__((packed, aligned(pointer_size)))
697    //
698    // which is
699    //
700    // <{ [1 x i32] }> in LLVM
701    //
702    if (RSObjectLLVMType == NULL) {
703      std::vector<const llvm::Type *> Elements;
704      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
705      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
706    }
707    return RSObjectLLVMType;
708  }
709
710  switch (mType) {
711    case DataTypeFloat32: {
712      return llvm::Type::getFloatTy(C);
713      break;
714    }
715    case DataTypeFloat64: {
716      return llvm::Type::getDoubleTy(C);
717      break;
718    }
719    case DataTypeBoolean: {
720      return llvm::Type::getInt1Ty(C);
721      break;
722    }
723    case DataTypeSigned8:
724    case DataTypeUnsigned8: {
725      return llvm::Type::getInt8Ty(C);
726      break;
727    }
728    case DataTypeSigned16:
729    case DataTypeUnsigned16:
730    case DataTypeUnsigned565:
731    case DataTypeUnsigned5551:
732    case DataTypeUnsigned4444: {
733      return llvm::Type::getInt16Ty(C);
734      break;
735    }
736    case DataTypeSigned32:
737    case DataTypeUnsigned32: {
738      return llvm::Type::getInt32Ty(C);
739      break;
740    }
741    case DataTypeSigned64:
742    case DataTypeUnsigned64: {
743      return llvm::Type::getInt64Ty(C);
744      break;
745    }
746    default: {
747      assert(false && "Unknown data type");
748    }
749  }
750
751  return NULL;
752}
753
754union RSType *RSExportPrimitiveType::convertToSpecType() const {
755  llvm::OwningPtr<union RSType> ST(new union RSType);
756  RS_TYPE_SET_CLASS(ST, RS_TC_Primitive);
757  // enum RSExportPrimitiveType::DataType is synced with enum RSDataType in
758  // slang_rs_type_spec.h
759  RS_PRIMITIVE_TYPE_SET_DATA_TYPE(ST, getType());
760  return ST.take();
761}
762
763bool RSExportPrimitiveType::equals(const RSExportable *E) const {
764  CHECK_PARENT_EQUALITY(RSExportType, E);
765  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
766}
767
768/**************************** RSExportPointerType ****************************/
769
770const clang::Type *RSExportPointerType::IntegerType = NULL;
771
772RSExportPointerType
773*RSExportPointerType::Create(RSContext *Context,
774                             const clang::PointerType *PT,
775                             const llvm::StringRef &TypeName) {
776  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
777  const RSExportType *PointeeET;
778
779  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
780    PointeeET = RSExportType::Create(Context, PointeeType);
781  } else {
782    // Double or higher dimension of pointer, export as int*
783    assert(IntegerType != NULL && "Built-in integer type is not set");
784    PointeeET = RSExportPrimitiveType::Create(Context, IntegerType);
785  }
786
787  if (PointeeET == NULL) {
788    fprintf(stderr, "Failed to create type for pointee");
789    return NULL;
790  }
791
792  return new RSExportPointerType(Context, TypeName, PointeeET);
793}
794
795const llvm::Type *RSExportPointerType::convertToLLVMType() const {
796  const llvm::Type *PointeeType = mPointeeType->getLLVMType();
797  return llvm::PointerType::getUnqual(PointeeType);
798}
799
800union RSType *RSExportPointerType::convertToSpecType() const {
801  llvm::OwningPtr<union RSType> ST(new union RSType);
802
803  RS_TYPE_SET_CLASS(ST, RS_TC_Pointer);
804  RS_POINTER_TYPE_SET_POINTEE_TYPE(ST, getPointeeType()->getSpecType());
805
806  if (RS_POINTER_TYPE_GET_POINTEE_TYPE(ST) != NULL)
807    return ST.take();
808  else
809    return NULL;
810}
811
812bool RSExportPointerType::keep() {
813  if (!RSExportType::keep())
814    return false;
815  const_cast<RSExportType*>(mPointeeType)->keep();
816  return true;
817}
818
819bool RSExportPointerType::equals(const RSExportable *E) const {
820  CHECK_PARENT_EQUALITY(RSExportType, E);
821  return (static_cast<const RSExportPointerType*>(E)
822              ->getPointeeType()->equals(getPointeeType()));
823}
824
825/***************************** RSExportVectorType *****************************/
826llvm::StringRef
827RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
828  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
829
830  if ((ElementType->getTypeClass() != clang::Type::Builtin))
831    return llvm::StringRef();
832
833  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType,
834                                                  ElementType);
835  if ((EVT->getNumElements() < 1) ||
836      (EVT->getNumElements() > 4))
837    return llvm::StringRef();
838
839  switch (BT->getKind()) {
840    // Compiler is smart enough to optimize following *big if branches* since
841    // they all become "constant comparison" after macro expansion
842#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
843    case builtin_type: {                                      \
844      const char *Name[] = { cname"2", cname"3", cname"4" };  \
845      return Name[EVT->getNumElements() - 2];                 \
846      break;                                                  \
847    }
848#include "RSClangBuiltinEnums.inc"
849    default: {
850      return llvm::StringRef();
851    }
852  }
853}
854
855RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
856                                               const clang::ExtVectorType *EVT,
857                                               const llvm::StringRef &TypeName,
858                                               DataKind DK,
859                                               bool Normalized) {
860  assert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
861
862  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
863  RSExportPrimitiveType::DataType DT =
864      RSExportPrimitiveType::GetDataType(ElementType);
865
866  if (DT != RSExportPrimitiveType::DataTypeUnknown)
867    return new RSExportVectorType(Context,
868                                  TypeName,
869                                  DT,
870                                  DK,
871                                  Normalized,
872                                  EVT->getNumElements());
873  else
874    fprintf(stderr, "RSExportVectorType::Create : unsupported base element "
875                    "type\n");
876  return NULL;
877}
878
879const llvm::Type *RSExportVectorType::convertToLLVMType() const {
880  const llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
881  return llvm::VectorType::get(ElementType, getNumElement());
882}
883
884union RSType *RSExportVectorType::convertToSpecType() const {
885  llvm::OwningPtr<union RSType> ST(new union RSType);
886
887  RS_TYPE_SET_CLASS(ST, RS_TC_Vector);
888  RS_VECTOR_TYPE_SET_ELEMENT_TYPE(ST, getType());
889  RS_VECTOR_TYPE_SET_VECTOR_SIZE(ST, getNumElement());
890
891  return ST.take();
892}
893
894bool RSExportVectorType::equals(const RSExportable *E) const {
895  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
896  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
897              == getNumElement());
898}
899
900/***************************** RSExportMatrixType *****************************/
901RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
902                                               const clang::RecordType *RT,
903                                               const llvm::StringRef &TypeName,
904                                               unsigned Dim) {
905  assert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
906  assert((Dim > 1) && "Invalid dimension of matrix");
907
908  // Check whether the struct rs_matrix is in our expected form (but assume it's
909  // correct if we're not sure whether it's correct or not)
910  const clang::RecordDecl* RD = RT->getDecl();
911  RD = RD->getDefinition();
912  if (RD != NULL) {
913    // Find definition, perform further examination
914    if (RD->field_empty()) {
915      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
916                      "must have 1 field for saving values", TypeName.data());
917      return NULL;
918    }
919
920    clang::RecordDecl::field_iterator FIT = RD->field_begin();
921    const clang::FieldDecl *FD = *FIT;
922    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
923    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
924      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
925                      "first field should be an array with constant size",
926              TypeName.data());
927      return NULL;
928    }
929    const clang::ConstantArrayType *CAT =
930      static_cast<const clang::ConstantArrayType *>(FT);
931    const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
932    if ((ElementType == NULL) ||
933        (ElementType->getTypeClass() != clang::Type::Builtin) ||
934        (static_cast<const clang::BuiltinType *>(ElementType)->getKind()
935          != clang::BuiltinType::Float)) {
936      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
937                      "first field should be a float array", TypeName.data());
938      return NULL;
939    }
940
941    if (CAT->getSize() != Dim * Dim) {
942      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
943                      "first field should be an array with size %d",
944              TypeName.data(), Dim * Dim);
945      return NULL;
946    }
947
948    FIT++;
949    if (FIT != RD->field_end()) {
950      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
951                      "must have exactly 1 field", TypeName.data());
952      return NULL;
953    }
954  }
955
956  return new RSExportMatrixType(Context, TypeName, Dim);
957}
958
959const llvm::Type *RSExportMatrixType::convertToLLVMType() const {
960  // Construct LLVM type:
961  // struct {
962  //  float X[mDim * mDim];
963  // }
964
965  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
966  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
967                                            mDim * mDim);
968  return llvm::StructType::get(C, X, NULL);
969}
970
971union RSType *RSExportMatrixType::convertToSpecType() const {
972  llvm::OwningPtr<union RSType> ST(new union RSType);
973  RS_TYPE_SET_CLASS(ST, RS_TC_Matrix);
974  switch (getDim()) {
975    case 2: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix2x2); break;
976    case 3: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix3x3); break;
977    case 4: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix4x4); break;
978    default: assert(false && "Matrix type with unsupported dimension.");
979  }
980  return ST.take();
981}
982
983bool RSExportMatrixType::equals(const RSExportable *E) const {
984  CHECK_PARENT_EQUALITY(RSExportType, E);
985  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
986}
987
988/************************* RSExportConstantArrayType *************************/
989RSExportConstantArrayType
990*RSExportConstantArrayType::Create(RSContext *Context,
991                                   const clang::ConstantArrayType *CAT) {
992  assert(CAT != NULL && CAT->getTypeClass() == clang::Type::ConstantArray);
993
994  assert((CAT->getSize().getActiveBits() < 32) && "array too large");
995
996  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
997  assert((Size > 0) && "Constant array should have size greater than 0");
998
999  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
1000  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
1001
1002  if (ElementET == NULL) {
1003    fprintf(stderr, "RSExportConstantArrayType::Create : failed to create "
1004                    "RSExportType for array element.\n");
1005    return NULL;
1006  }
1007
1008  return new RSExportConstantArrayType(Context,
1009                                       ElementET,
1010                                       Size);
1011}
1012
1013const llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1014  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
1015}
1016
1017union RSType *RSExportConstantArrayType::convertToSpecType() const {
1018  llvm::OwningPtr<union RSType> ST(new union RSType);
1019
1020  RS_TYPE_SET_CLASS(ST, RS_TC_ConstantArray);
1021  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_TYPE(
1022      ST, getElementType()->getSpecType());
1023  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_SIZE(ST, getSize());
1024
1025  if (RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(ST) != NULL)
1026    return ST.take();
1027  else
1028    return NULL;
1029}
1030
1031bool RSExportConstantArrayType::keep() {
1032  if (!RSExportType::keep())
1033    return false;
1034  const_cast<RSExportType*>(mElementType)->keep();
1035  return true;
1036}
1037
1038bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1039  CHECK_PARENT_EQUALITY(RSExportType, E);
1040  const RSExportConstantArrayType *RHS =
1041      static_cast<const RSExportConstantArrayType*>(E);
1042  return ((getSize() == RHS->getSize()) &&
1043          (getElementType()->equals(RHS->getElementType())));
1044}
1045
1046/**************************** RSExportRecordType ****************************/
1047RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1048                                               const clang::RecordType *RT,
1049                                               const llvm::StringRef &TypeName,
1050                                               bool mIsArtificial) {
1051  assert(RT != NULL && RT->getTypeClass() == clang::Type::Record);
1052
1053  const clang::RecordDecl *RD = RT->getDecl();
1054  assert(RD->isStruct());
1055
1056  RD = RD->getDefinition();
1057  if (RD == NULL) {
1058    // TODO(zonr): warn that actual struct definition isn't declared in this
1059    //             moudle.
1060    fprintf(stderr, "RSExportRecordType::Create : this struct is not defined "
1061                    "in this module.");
1062    return NULL;
1063  }
1064
1065  // Struct layout construct by clang. We rely on this for obtaining the
1066  // alloc size of a struct and offset of every field in that struct.
1067  const clang::ASTRecordLayout *RL =
1068      &Context->getASTContext().getASTRecordLayout(RD);
1069  assert((RL != NULL) && "Failed to retrieve the struct layout from Clang.");
1070
1071  RSExportRecordType *ERT =
1072      new RSExportRecordType(Context,
1073                             TypeName,
1074                             RD->hasAttr<clang::PackedAttr>(),
1075                             mIsArtificial,
1076                             (RL->getSize() >> 3));
1077  unsigned int Index = 0;
1078
1079  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1080           FE = RD->field_end();
1081       FI != FE;
1082       FI++, Index++) {
1083#define FAILED_CREATE_FIELD(err)    do {         \
1084      if (*err)                                                          \
1085        fprintf(stderr, \
1086                "RSExportRecordType::Create : failed to create field (%s)\n", \
1087                err);                                                   \
1088      delete ERT;                                                       \
1089      return NULL;                                                      \
1090    } while (false)
1091
1092    // FIXME: All fields should be primitive type
1093    assert((*FI)->getKind() == clang::Decl::Field);
1094    clang::FieldDecl *FD = *FI;
1095
1096    // We don't support bit field
1097    //
1098    // TODO(zonr): allow bitfield with size 8, 16, 32
1099    if (FD->isBitField())
1100      FAILED_CREATE_FIELD("bit field is not supported");
1101
1102    // Type
1103    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1104
1105    if (ET != NULL)
1106      ERT->mFields.push_back(
1107          new Field(ET, FD->getName(), ERT,
1108                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1109    else
1110      FAILED_CREATE_FIELD(FD->getName().str().c_str());
1111#undef FAILED_CREATE_FIELD
1112  }
1113
1114  return ERT;
1115}
1116
// Lowers this record to an LLVM struct type with one member per exported
// field, in field-list order, packed according to mIsPacked.
//
// An opaque placeholder type is registered through setAbstractLLVMType()
// BEFORE any field type is lowered. Presumably this lets a field whose
// lowering leads back to this record pick up the placeholder instead of
// recursing forever (NOTE(review): confirm in RSExportType::getLLVMType).
// Once the concrete struct exists, the placeholder is refined into it.
const llvm::Type *RSExportRecordType::convertToLLVMType() const {
  // Create an opaque type since struct may reference itself recursively.
  llvm::PATypeHolder ResultHolder =
      llvm::OpaqueType::get(getRSContext()->getLLVMContext());
  setAbstractLLVMType(ResultHolder.get());

  std::vector<const llvm::Type*> FieldTypes;

  // Lower each field's type; this happens after the placeholder is published.
  for (const_field_iterator FI = fields_begin(), FE = fields_end();
       FI != FE;
       FI++) {
    const Field *F = *FI;
    const RSExportType *FET = F->getType();

    FieldTypes.push_back(FET->getLLVMType());
  }

  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
                                               FieldTypes,
                                               mIsPacked);
  // Replace the opaque placeholder with the concrete struct. PATypeHolder is
  // meant to follow abstract-type refinement, so ResultHolder.get() below
  // should then yield the refined type (see LLVM's abstract type docs).
  if (ST != NULL)
    static_cast<llvm::OpaqueType*>(ResultHolder.get())
        ->refineAbstractTypeTo(ST);
  else
    return NULL;
  return ResultHolder.get();
}
1143}
1144
// Builds an RS_TC_Record spec type describing this struct. The spec records:
//   - the struct's name;
//   - its field count;
//   - for each field, in field-list order, the field's name, spec type and
//     data kind.
// Ownership of the returned block is released to the caller.
//
// The RSType header and NumFields RSRecordField entries are carved out of one
// zero-filled allocation of sizeof(RSType) + NumFields * sizeof(RSRecordField)
// bytes. Presumably the record variant of RSType ends in a trailing field
// array; confirm in slang_rs_type_spec.h.
//
// NOTE(review): the name pointers stored below come from getName().c_str()
// and F->getName().c_str(). If those getters return by value, the pointers
// dangle immediately. Otherwise they stay valid only while this object and
// its fields are alive, unless the SET macros copy the strings. Verify.
union RSType *RSExportRecordType::convertToSpecType() const {
  unsigned NumFields = getFields().size();
  unsigned AllocSize = sizeof(union RSType) +
                       sizeof(struct RSRecordField) * NumFields;
  llvm::OwningPtr<union RSType> ST(
      reinterpret_cast<union RSType*>(operator new(AllocSize)));

  // Zero the header and every field slot so unset members read as 0/NULL.
  ::memset(ST.get(), 0, AllocSize);

  RS_TYPE_SET_CLASS(ST, RS_TC_Record);
  RS_RECORD_TYPE_SET_NAME(ST, getName().c_str());
  RS_RECORD_TYPE_SET_NUM_FIELDS(ST, NumFields);

  // Publish the partially built spec before converting field types.
  // Presumably this lets a field that refers back to this record resolve to
  // this spec instead of recursing (NOTE(review): confirm in RSExportType).
  setSpecTypeTemporarily(ST.get());

  unsigned FieldIdx = 0;
  for (const_field_iterator FI = fields_begin(), FE = fields_end();
       FI != FE;
       FI++, FieldIdx++) {
    const Field *F = *FI;

    RS_RECORD_TYPE_SET_FIELD_NAME(ST, FieldIdx, F->getName().c_str());
    RS_RECORD_TYPE_SET_FIELD_TYPE(ST, FieldIdx, F->getType()->getSpecType());

    // Primitive and vector fields carry their own data kind. Every other
    // field class is tagged RS_DK_User.
    enum RSDataKind DK = RS_DK_User;
    if ((F->getType()->getClass() == ExportClassPrimitive) ||
        (F->getType()->getClass() == ExportClassVector)) {
      const RSExportPrimitiveType *EPT =
        static_cast<const RSExportPrimitiveType*>(F->getType());
      // enum RSExportPrimitiveType::DataKind is synced with enum RSDataKind in
      // slang_rs_type_spec.h
      DK = static_cast<enum RSDataKind>(EPT->getKind());
    }
    RS_RECORD_TYPE_SET_FIELD_DATA_KIND(ST, FieldIdx, DK);
  }

  // TODO(slang): Check whether all fields were created normally.

  return ST.take();
}
1184}
1185
1186bool RSExportRecordType::keep() {
1187  if (!RSExportType::keep())
1188    return false;
1189  for (std::list<const Field*>::iterator I = mFields.begin(),
1190          E = mFields.end();
1191       I != E;
1192       I++) {
1193    const_cast<RSExportType*>((*I)->getType())->keep();
1194  }
1195  return true;
1196}
1197
1198bool RSExportRecordType::equals(const RSExportable *E) const {
1199  CHECK_PARENT_EQUALITY(RSExportType, E);
1200
1201  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1202
1203  if (ERT->getFields().size() != getFields().size())
1204    return false;
1205
1206  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1207
1208  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1209    if (!(*AI)->getType()->equals((*BI)->getType()))
1210      return false;
1211    AI++;
1212    BI++;
1213  }
1214
1215  return true;
1216}
1217
1218}  // namespace slang
1219