slang_rs_export_type.cpp revision feaca06fcb0772e9e972a0d61b17259fc5124d50
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <list>
20#include <vector>
21
22#include "clang/AST/RecordLayout.h"
23
24#include "llvm/ADT/StringExtras.h"
25
26#include "llvm/DerivedTypes.h"
27
28#include "llvm/Target/TargetData.h"
29
30#include "llvm/Type.h"
31
32#include "slang_rs_context.h"
33#include "slang_rs_export_element.h"
34#include "slang_rs_type_spec.h"
35
// Makes the enclosing equals() override return false as soon as the parent
// class's equals() rejects E. Wrapped in do { } while (0) so the expansion is
// a single statement: it is safe inside an unbraced if/else and still
// requires the trailing semicolon written at every call site.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  do {                                        \
    if (!ParentClass::equals(E))              \
      return false;                           \
  } while (0)
39
40namespace slang {
41
42namespace {
43
44static const clang::Type *TypeExportableHelper(
45    const clang::Type *T,
46    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
47    clang::Diagnostic *Diags,
48    clang::SourceManager *SM,
49    const clang::VarDecl *VD,
50    const clang::RecordDecl *TopLevelRecord);
51
52static void ReportTypeError(clang::Diagnostic *Diags,
53                            const clang::SourceManager *SM,
54                            const clang::VarDecl *VD,
55                            const clang::RecordDecl *TopLevelRecord,
56                            const char *Message) {
57  if (!Diags || !SM) {
58    return;
59  }
60
61  // Attempt to use the type declaration first (if we have one).
62  // Fall back to the variable definition, if we are looking at something
63  // like an array declaration that can't be exported.
64  if (TopLevelRecord) {
65    Diags->Report(clang::FullSourceLoc(TopLevelRecord->getLocation(), *SM),
66                  Diags->getCustomDiagID(clang::Diagnostic::Error, Message))
67         << TopLevelRecord->getName();
68  } else if (VD) {
69    Diags->Report(clang::FullSourceLoc(VD->getLocation(), *SM),
70                  Diags->getCustomDiagID(clang::Diagnostic::Error, Message))
71         << VD->getName();
72  } else {
73    assert(false && "Variables should be validated before exporting");
74  }
75
76  return;
77}
78
79static const clang::Type *ConstantArrayTypeExportableHelper(
80    const clang::ConstantArrayType *CAT,
81    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
82    clang::Diagnostic *Diags,
83    clang::SourceManager *SM,
84    const clang::VarDecl *VD,
85    const clang::RecordDecl *TopLevelRecord) {
86  // Check element type
87  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
88  if (ElementType->isArrayType()) {
89    ReportTypeError(Diags, SM, VD, TopLevelRecord,
90        "multidimensional arrays cannot be exported: '%0'");
91    return NULL;
92  } else if (ElementType->isExtVectorType()) {
93    const clang::ExtVectorType *EVT =
94        static_cast<const clang::ExtVectorType*>(ElementType);
95    unsigned numElements = EVT->getNumElements();
96
97    const clang::Type *BaseElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
98    if (!RSExportPrimitiveType::IsPrimitiveType(BaseElementType)) {
99      ReportTypeError(Diags, SM, VD, TopLevelRecord,
100          "vectors of non-primitive types cannot be exported: '%0'");
101      return NULL;
102    }
103
104    if (numElements == 3 && CAT->getSize() != 1) {
105      ReportTypeError(Diags, SM, VD, TopLevelRecord,
106          "arrays of width 3 vector types cannot be exported: '%0'");
107      return NULL;
108    }
109  }
110
111  if (TypeExportableHelper(ElementType, SPS, Diags, SM, VD,
112                           TopLevelRecord) == NULL)
113    return NULL;
114  else
115    return CAT;
116}
117
118static const clang::Type *TypeExportableHelper(
119    const clang::Type *T,
120    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
121    clang::Diagnostic *Diags,
122    clang::SourceManager *SM,
123    const clang::VarDecl *VD,
124    const clang::RecordDecl *TopLevelRecord) {
125  // Normalize first
126  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
127    return NULL;
128
129  if (SPS.count(T))
130    return T;
131
132  switch (T->getTypeClass()) {
133    case clang::Type::Builtin: {
134      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
135
136      switch (BT->getKind()) {
137#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
138        case builtin_type:
139#include "RSClangBuiltinEnums.inc"
140          return T;
141        default: {
142          return NULL;
143        }
144      }
145    }
146    case clang::Type::Record: {
147      if (RSExportPrimitiveType::GetRSSpecificType(T) !=
148          RSExportPrimitiveType::DataTypeUnknown)
149        return T;  // RS object type, no further checks are needed
150
151      // Check internal struct
152      if (T->isUnionType()) {
153        ReportTypeError(Diags, SM, NULL, T->getAsUnionType()->getDecl(),
154            "unions cannot be exported: '%0'");
155        return NULL;
156      } else if (!T->isStructureType()) {
157        assert(false && "Unknown type cannot be exported");
158        return NULL;
159      }
160
161      clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
162      if (RD != NULL) {
163        RD = RD->getDefinition();
164        if (RD == NULL) {
165          ReportTypeError(Diags, SM, NULL, T->getAsStructureType()->getDecl(),
166              "struct is not defined in this module");
167          return NULL;
168        }
169      }
170
171      if (!TopLevelRecord) {
172        TopLevelRecord = RD;
173      }
174      if (RD->getName().empty()) {
175        ReportTypeError(Diags, SM, NULL, RD,
176            "anonymous structures cannot be exported");
177        return NULL;
178      }
179
180      // Fast check
181      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
182        return NULL;
183
184      // Insert myself into checking set
185      SPS.insert(T);
186
187      // Check all element
188      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
189               FE = RD->field_end();
190           FI != FE;
191           FI++) {
192        const clang::FieldDecl *FD = *FI;
193        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
194        FT = GET_CANONICAL_TYPE(FT);
195
196        if (!TypeExportableHelper(FT, SPS, Diags, SM, VD, TopLevelRecord)) {
197          return NULL;
198        }
199
200        // We don't support bit fields yet
201        //
202        // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
203        if (FD->isBitField()) {
204          if (Diags && SM) {
205            Diags->Report(clang::FullSourceLoc(FD->getLocation(), *SM),
206                          Diags->getCustomDiagID(clang::Diagnostic::Error,
207                          "bit fields are not able to be exported: '%0.%1'"))
208                << RD->getName()
209                << FD->getName();
210          }
211          return NULL;
212        }
213      }
214
215      return T;
216    }
217    case clang::Type::Pointer: {
218      if (TopLevelRecord) {
219        ReportTypeError(Diags, SM, NULL, TopLevelRecord,
220            "structures containing pointers cannot be exported: '%0'");
221        return NULL;
222      }
223
224      const clang::PointerType *PT = UNSAFE_CAST_TYPE(clang::PointerType, T);
225      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
226
227      if (PointeeType->getTypeClass() == clang::Type::Pointer)
228        return T;
229      // We don't support pointer with array-type pointee or unsupported pointee
230      // type
231      if (PointeeType->isArrayType() ||
232          (TypeExportableHelper(PointeeType, SPS, Diags, SM, VD,
233                                TopLevelRecord) == NULL))
234        return NULL;
235      else
236        return T;
237    }
238    case clang::Type::ExtVector: {
239      const clang::ExtVectorType *EVT =
240          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
241      // Only vector with size 2, 3 and 4 are supported.
242      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
243        return NULL;
244
245      // Check base element type
246      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
247
248      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
249          (TypeExportableHelper(ElementType, SPS, Diags, SM, VD,
250                                TopLevelRecord) == NULL))
251        return NULL;
252      else
253        return T;
254    }
255    case clang::Type::ConstantArray: {
256      const clang::ConstantArrayType *CAT =
257          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);
258
259      return ConstantArrayTypeExportableHelper(CAT, SPS, Diags, SM, VD,
260                                               TopLevelRecord);
261    }
262    default: {
263      return NULL;
264    }
265  }
266}
267
268// Return the type that can be used to create RSExportType, will always return
269// the canonical type
270// If the Type T is not exportable, this function returns NULL. Diags and SM
271// are used to generate proper Clang diagnostic messages when a
272// non-exportable type is detected. TopLevelRecord is used to capture the
273// highest struct (in the case of a nested hierarchy) for detecting other
274// types that cannot be exported (mostly pointers within a struct).
275static const clang::Type *TypeExportable(const clang::Type *T,
276                                         clang::Diagnostic *Diags,
277                                         clang::SourceManager *SM,
278                                         const clang::VarDecl *VD) {
279  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
280      llvm::SmallPtrSet<const clang::Type*, 8>();
281
282  return TypeExportableHelper(T, SPS, Diags, SM, VD, NULL);
283}
284
285}  // namespace
286
287/****************************** RSExportType ******************************/
288bool RSExportType::NormalizeType(const clang::Type *&T,
289                                 llvm::StringRef &TypeName,
290                                 clang::Diagnostic *Diags,
291                                 clang::SourceManager *SM,
292                                 const clang::VarDecl *VD) {
293  if ((T = TypeExportable(T, Diags, SM, VD)) == NULL) {
294    return false;
295  }
296  // Get type name
297  TypeName = RSExportType::GetTypeName(T);
298  if (TypeName.empty()) {
299    if (Diags && SM) {
300      if (VD) {
301        Diags->Report(clang::FullSourceLoc(VD->getLocation(), *SM),
302                      Diags->getCustomDiagID(clang::Diagnostic::Error,
303                                             "anonymous types cannot "
304                                             "be exported"));
305      } else {
306        Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
307                                             "anonymous types cannot "
308                                             "be exported"));
309      }
310    }
311    return false;
312  }
313
314  return true;
315}
316
317const clang::Type
318*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
319  if (DD) {
320    clang::QualType T;
321    if (DD->getTypeSourceInfo())
322      T = DD->getTypeSourceInfo()->getType();
323    else
324      T = DD->getType();
325
326    if (T.isNull())
327      return NULL;
328    else
329      return T.getTypePtr();
330  }
331  return NULL;
332}
333
334llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
335  T = GET_CANONICAL_TYPE(T);
336  if (T == NULL)
337    return llvm::StringRef();
338
339  switch (T->getTypeClass()) {
340    case clang::Type::Builtin: {
341      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
342
343      switch (BT->getKind()) {
344#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
345        case builtin_type:                                    \
346          return cname;                                       \
347        break;
348#include "RSClangBuiltinEnums.inc"
349        default: {
350          assert(false && "Unknown data type of the builtin");
351          break;
352        }
353      }
354      break;
355    }
356    case clang::Type::Record: {
357      clang::RecordDecl *RD;
358      if (T->isStructureType()) {
359        RD = T->getAsStructureType()->getDecl();
360      } else {
361        break;
362      }
363
364      llvm::StringRef Name = RD->getName();
365      if (Name.empty()) {
366          if (RD->getTypedefForAnonDecl() != NULL)
367            Name = RD->getTypedefForAnonDecl()->getName();
368
369          if (Name.empty())
370            // Try to find a name from redeclaration (i.e. typedef)
371            for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
372                     RE = RD->redecls_end();
373                 RI != RE;
374                 RI++) {
375              assert(*RI != NULL && "cannot be NULL object");
376
377              Name = (*RI)->getName();
378              if (!Name.empty())
379                break;
380            }
381      }
382      return Name;
383    }
384    case clang::Type::Pointer: {
385      // "*" plus pointee name
386      const clang::Type *PT = GET_POINTEE_TYPE(T);
387      llvm::StringRef PointeeName;
388      if (NormalizeType(PT, PointeeName, NULL, NULL, NULL)) {
389        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
390        Name[0] = '*';
391        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
392        Name[PointeeName.size() + 1] = '\0';
393        return Name;
394      }
395      break;
396    }
397    case clang::Type::ExtVector: {
398      const clang::ExtVectorType *EVT =
399          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
400      return RSExportVectorType::GetTypeName(EVT);
401      break;
402    }
403    case clang::Type::ConstantArray : {
404      // Construct name for a constant array is too complicated.
405      return DUMMY_TYPE_NAME_FOR_RS_CONSTANT_ARRAY_TYPE;
406    }
407    default: {
408      break;
409    }
410  }
411
412  return llvm::StringRef();
413}
414
415
416RSExportType *RSExportType::Create(RSContext *Context,
417                                   const clang::Type *T,
418                                   const llvm::StringRef &TypeName) {
419  // Lookup the context to see whether the type was processed before.
420  // Newly created RSExportType will insert into context
421  // in RSExportType::RSExportType()
422  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
423
424  if (ETI != Context->export_types_end())
425    return ETI->second;
426
427  RSExportType *ET = NULL;
428  switch (T->getTypeClass()) {
429    case clang::Type::Record: {
430      RSExportPrimitiveType::DataType dt =
431          RSExportPrimitiveType::GetRSSpecificType(TypeName);
432      switch (dt) {
433        case RSExportPrimitiveType::DataTypeUnknown: {
434          // User-defined types
435          ET = RSExportRecordType::Create(Context,
436                                          T->getAsStructureType(),
437                                          TypeName);
438          break;
439        }
440        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
441          // 2 x 2 Matrix type
442          ET = RSExportMatrixType::Create(Context,
443                                          T->getAsStructureType(),
444                                          TypeName,
445                                          2);
446          break;
447        }
448        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
449          // 3 x 3 Matrix type
450          ET = RSExportMatrixType::Create(Context,
451                                          T->getAsStructureType(),
452                                          TypeName,
453                                          3);
454          break;
455        }
456        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
457          // 4 x 4 Matrix type
458          ET = RSExportMatrixType::Create(Context,
459                                          T->getAsStructureType(),
460                                          TypeName,
461                                          4);
462          break;
463        }
464        default: {
465          // Others are primitive types
466          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
467          break;
468        }
469      }
470      break;
471    }
472    case clang::Type::Builtin: {
473      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
474      break;
475    }
476    case clang::Type::Pointer: {
477      ET = RSExportPointerType::Create(Context,
478                                       UNSAFE_CAST_TYPE(clang::PointerType, T),
479                                       TypeName);
480      // FIXME: free the name (allocated in RSExportType::GetTypeName)
481      delete [] TypeName.data();
482      break;
483    }
484    case clang::Type::ExtVector: {
485      ET = RSExportVectorType::Create(Context,
486                                      UNSAFE_CAST_TYPE(clang::ExtVectorType, T),
487                                      TypeName);
488      break;
489    }
490    case clang::Type::ConstantArray: {
491      ET = RSExportConstantArrayType::Create(
492              Context,
493              UNSAFE_CAST_TYPE(clang::ConstantArrayType, T));
494      break;
495    }
496    default: {
497      clang::Diagnostic *Diags = Context->getDiagnostics();
498      Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
499                        "unknown type cannot be exported: '%0'"))
500          << T->getTypeClassName();
501      break;
502    }
503  }
504
505  return ET;
506}
507
508RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
509  llvm::StringRef TypeName;
510  if (NormalizeType(T, TypeName, NULL, NULL, NULL))
511    return Create(Context, T, TypeName);
512  else
513    return NULL;
514}
515
516RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
517                                           const clang::VarDecl *VD) {
518  return RSExportType::Create(Context, GetTypeOfDecl(VD));
519}
520
521size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
522  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
523      ET->getLLVMType());
524}
525
526size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
527  if (ET->getClass() == RSExportType::ExportClassRecord)
528    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
529  else
530    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
531        ET->getLLVMType());
532}
533
534RSExportType::RSExportType(RSContext *Context,
535                           ExportClass Class,
536                           const llvm::StringRef &Name)
537    : RSExportable(Context, RSExportable::EX_TYPE),
538      mClass(Class),
539      // Make a copy on Name since memory stored @Name is either allocated in
540      // ASTContext or allocated in GetTypeName which will be destroyed later.
541      mName(Name.data(), Name.size()),
542      mLLVMType(NULL),
543      mSpecType(NULL) {
544  // Don't cache the type whose name start with '<'. Those type failed to
545  // get their name since constructing their name in GetTypeName() requiring
546  // complicated work.
547  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
548    // TODO(zonr): Need to check whether the insertion is successful or not.
549    Context->insertExportType(llvm::StringRef(Name), this);
550  return;
551}
552
553bool RSExportType::keep() {
554  if (!RSExportable::keep())
555    return false;
556  // Invalidate converted LLVM type.
557  mLLVMType = NULL;
558  return true;
559}
560
561bool RSExportType::equals(const RSExportable *E) const {
562  CHECK_PARENT_EQUALITY(RSExportable, E);
563  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
564}
565
566RSExportType::~RSExportType() {
567  delete mSpecType;
568}
569
570/************************** RSExportPrimitiveType **************************/
571llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
572RSExportPrimitiveType::RSSpecificTypeMap;
573
574llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
575
576bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
577  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
578    return true;
579  else
580    return false;
581}
582
583RSExportPrimitiveType::DataType
584RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
585  if (TypeName.empty())
586    return DataTypeUnknown;
587
588  if (RSSpecificTypeMap->empty()) {
589#define ENUM_RS_MATRIX_TYPE(type, cname, dim)                       \
590    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
591#include "RSMatrixTypeEnums.inc"
592#define ENUM_RS_OBJECT_TYPE(type, cname)                            \
593    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
594#include "RSObjectTypeEnums.inc"
595  }
596
597  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
598  if (I == RSSpecificTypeMap->end())
599    return DataTypeUnknown;
600  else
601    return I->getValue();
602}
603
604RSExportPrimitiveType::DataType
605RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
606  T = GET_CANONICAL_TYPE(T);
607  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
608    return DataTypeUnknown;
609
610  return GetRSSpecificType( RSExportType::GetTypeName(T) );
611}
612
613bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
614  return ((DT >= FirstRSMatrixType) && (DT <= LastRSMatrixType));
615}
616
617bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
618  return ((DT >= FirstRSObjectType) && (DT <= LastRSObjectType));
619}
620
621bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
622  bool RSObjectTypeSeen = false;
623  while (T && T->isArrayType()) {
624    T = T->getArrayElementTypeNoTypeQual();
625  }
626
627  const clang::RecordType *RT = T->getAsStructureType();
628  if (!RT) {
629    return false;
630  }
631  const clang::RecordDecl *RD = RT->getDecl();
632  RD = RD->getDefinition();
633  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
634         FE = RD->field_end();
635       FI != FE;
636       FI++) {
637    // We just look through all field declarations to see if we find a
638    // declaration for an RS object type (or an array of one).
639    const clang::FieldDecl *FD = *FI;
640    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
641    while (FT && FT->isArrayType()) {
642      FT = FT->getArrayElementTypeNoTypeQual();
643    }
644
645    RSExportPrimitiveType::DataType DT = GetRSSpecificType(FT);
646    if (IsRSObjectType(DT)) {
647      // RS object types definitely need to be zero-initialized
648      RSObjectTypeSeen = true;
649    } else {
650      switch (DT) {
651        case RSExportPrimitiveType::DataTypeRSMatrix2x2:
652        case RSExportPrimitiveType::DataTypeRSMatrix3x3:
653        case RSExportPrimitiveType::DataTypeRSMatrix4x4:
654          // Matrix types should get zero-initialized as well
655          RSObjectTypeSeen = true;
656          break;
657        default:
658          // Ignore all other primitive types
659          break;
660      }
661      while (FT && FT->isArrayType()) {
662        FT = FT->getArrayElementTypeNoTypeQual();
663      }
664      if (FT->isStructureType()) {
665        // Recursively handle structs of structs (even though these can't
666        // be exported, it is possible for a user to have them internally).
667        RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
668      }
669    }
670  }
671
672  return RSObjectTypeSeen;
673}
674
675const size_t RSExportPrimitiveType::SizeOfDataTypeInBits[] = {
676#define ENUM_RS_DATA_TYPE(type, cname, bits)  \
677  bits,
678#include "RSDataTypeEnums.inc"
679  0   // DataTypeMax
680};
681
682size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
683  assert(((EPT->getType() > DataTypeUnknown) &&
684          (EPT->getType() < DataTypeMax)) &&
685         "RSExportPrimitiveType::GetSizeInBits : unknown data type");
686  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
687}
688
689RSExportPrimitiveType::DataType
690RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
691  if (T == NULL)
692    return DataTypeUnknown;
693
694  switch (T->getTypeClass()) {
695    case clang::Type::Builtin: {
696      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
697      switch (BT->getKind()) {
698#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
699        case builtin_type: {                                  \
700          return DataType ## type;                            \
701        }
702#include "RSClangBuiltinEnums.inc"
703        // The size of type WChar depend on platform so we abandon the support
704        // to them.
705        default: {
706          clang::Diagnostic *Diags = Context->getDiagnostics();
707          Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
708                            "built-in type cannot be exported: '%0'"))
709              << T->getTypeClassName();
710          break;
711        }
712      }
713      break;
714    }
715    case clang::Type::Record: {
716      // must be RS object type
717      return RSExportPrimitiveType::GetRSSpecificType(T);
718    }
719    default: {
720      clang::Diagnostic *Diags = Context->getDiagnostics();
721      Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
722                        "primitive type cannot be exported: '%0'"))
723          << T->getTypeClassName();
724      break;
725    }
726  }
727
728  return DataTypeUnknown;
729}
730
731RSExportPrimitiveType
732*RSExportPrimitiveType::Create(RSContext *Context,
733                               const clang::Type *T,
734                               const llvm::StringRef &TypeName,
735                               DataKind DK,
736                               bool Normalized) {
737  DataType DT = GetDataType(Context, T);
738
739  if ((DT == DataTypeUnknown) || TypeName.empty())
740    return NULL;
741  else
742    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
743                                     DT, DK, Normalized);
744}
745
746RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
747                                                     const clang::Type *T,
748                                                     DataKind DK) {
749  llvm::StringRef TypeName;
750  if (RSExportType::NormalizeType(T, TypeName, NULL, NULL, NULL) &&
751      IsPrimitiveType(T)) {
752    return Create(Context, T, TypeName, DK);
753  } else {
754    return NULL;
755  }
756}
757
758const llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
759  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
760
761  if (isRSObjectType()) {
762    // struct {
763    //   int *p;
764    // } __attribute__((packed, aligned(pointer_size)))
765    //
766    // which is
767    //
768    // <{ [1 x i32] }> in LLVM
769    //
770    if (RSObjectLLVMType == NULL) {
771      std::vector<const llvm::Type *> Elements;
772      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
773      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
774    }
775    return RSObjectLLVMType;
776  }
777
778  switch (mType) {
779    case DataTypeFloat32: {
780      return llvm::Type::getFloatTy(C);
781      break;
782    }
783    case DataTypeFloat64: {
784      return llvm::Type::getDoubleTy(C);
785      break;
786    }
787    case DataTypeBoolean: {
788      return llvm::Type::getInt1Ty(C);
789      break;
790    }
791    case DataTypeSigned8:
792    case DataTypeUnsigned8: {
793      return llvm::Type::getInt8Ty(C);
794      break;
795    }
796    case DataTypeSigned16:
797    case DataTypeUnsigned16:
798    case DataTypeUnsigned565:
799    case DataTypeUnsigned5551:
800    case DataTypeUnsigned4444: {
801      return llvm::Type::getInt16Ty(C);
802      break;
803    }
804    case DataTypeSigned32:
805    case DataTypeUnsigned32: {
806      return llvm::Type::getInt32Ty(C);
807      break;
808    }
809    case DataTypeSigned64:
810    case DataTypeUnsigned64: {
811      return llvm::Type::getInt64Ty(C);
812      break;
813    }
814    default: {
815      assert(false && "Unknown data type");
816    }
817  }
818
819  return NULL;
820}
821
822union RSType *RSExportPrimitiveType::convertToSpecType() const {
823  llvm::OwningPtr<union RSType> ST(new union RSType);
824  RS_TYPE_SET_CLASS(ST, RS_TC_Primitive);
825  // enum RSExportPrimitiveType::DataType is synced with enum RSDataType in
826  // slang_rs_type_spec.h
827  RS_PRIMITIVE_TYPE_SET_DATA_TYPE(ST, getType());
828  return ST.take();
829}
830
831bool RSExportPrimitiveType::equals(const RSExportable *E) const {
832  CHECK_PARENT_EQUALITY(RSExportType, E);
833  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
834}
835
836/**************************** RSExportPointerType ****************************/
837
838RSExportPointerType
839*RSExportPointerType::Create(RSContext *Context,
840                             const clang::PointerType *PT,
841                             const llvm::StringRef &TypeName) {
842  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
843  const RSExportType *PointeeET;
844
845  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
846    PointeeET = RSExportType::Create(Context, PointeeType);
847  } else {
848    // Double or higher dimension of pointer, export as int*
849    PointeeET = RSExportPrimitiveType::Create(Context,
850                    Context->getASTContext().IntTy.getTypePtr());
851  }
852
853  if (PointeeET == NULL) {
854    // Error diagnostic is emitted for corresponding pointee type
855    return NULL;
856  }
857
858  return new RSExportPointerType(Context, TypeName, PointeeET);
859}
860
861const llvm::Type *RSExportPointerType::convertToLLVMType() const {
862  const llvm::Type *PointeeType = mPointeeType->getLLVMType();
863  return llvm::PointerType::getUnqual(PointeeType);
864}
865
866union RSType *RSExportPointerType::convertToSpecType() const {
867  llvm::OwningPtr<union RSType> ST(new union RSType);
868
869  RS_TYPE_SET_CLASS(ST, RS_TC_Pointer);
870  RS_POINTER_TYPE_SET_POINTEE_TYPE(ST, getPointeeType()->getSpecType());
871
872  if (RS_POINTER_TYPE_GET_POINTEE_TYPE(ST) != NULL)
873    return ST.take();
874  else
875    return NULL;
876}
877
878bool RSExportPointerType::keep() {
879  if (!RSExportType::keep())
880    return false;
881  const_cast<RSExportType*>(mPointeeType)->keep();
882  return true;
883}
884
885bool RSExportPointerType::equals(const RSExportable *E) const {
886  CHECK_PARENT_EQUALITY(RSExportType, E);
887  return (static_cast<const RSExportPointerType*>(E)
888              ->getPointeeType()->equals(getPointeeType()));
889}
890
891/***************************** RSExportVectorType *****************************/
892llvm::StringRef
893RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
894  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
895
896  if ((ElementType->getTypeClass() != clang::Type::Builtin))
897    return llvm::StringRef();
898
899  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType,
900                                                  ElementType);
901  if ((EVT->getNumElements() < 1) ||
902      (EVT->getNumElements() > 4))
903    return llvm::StringRef();
904
905  switch (BT->getKind()) {
906    // Compiler is smart enough to optimize following *big if branches* since
907    // they all become "constant comparison" after macro expansion
908#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
909    case builtin_type: {                                      \
910      const char *Name[] = { cname"2", cname"3", cname"4" };  \
911      return Name[EVT->getNumElements() - 2];                 \
912      break;                                                  \
913    }
914#include "RSClangBuiltinEnums.inc"
915    default: {
916      return llvm::StringRef();
917    }
918  }
919}
920
921RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
922                                               const clang::ExtVectorType *EVT,
923                                               const llvm::StringRef &TypeName,
924                                               DataKind DK,
925                                               bool Normalized) {
926  assert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
927
928  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
929  RSExportPrimitiveType::DataType DT =
930      RSExportPrimitiveType::GetDataType(Context, ElementType);
931
932  if (DT != RSExportPrimitiveType::DataTypeUnknown)
933    return new RSExportVectorType(Context,
934                                  TypeName,
935                                  DT,
936                                  DK,
937                                  Normalized,
938                                  EVT->getNumElements());
939  else
940    return NULL;
941}
942
943const llvm::Type *RSExportVectorType::convertToLLVMType() const {
944  const llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
945  return llvm::VectorType::get(ElementType, getNumElement());
946}
947
948union RSType *RSExportVectorType::convertToSpecType() const {
949  llvm::OwningPtr<union RSType> ST(new union RSType);
950
951  RS_TYPE_SET_CLASS(ST, RS_TC_Vector);
952  RS_VECTOR_TYPE_SET_ELEMENT_TYPE(ST, getType());
953  RS_VECTOR_TYPE_SET_VECTOR_SIZE(ST, getNumElement());
954
955  return ST.take();
956}
957
958bool RSExportVectorType::equals(const RSExportable *E) const {
959  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
960  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
961              == getNumElement());
962}
963
964/***************************** RSExportMatrixType *****************************/
965RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
966                                               const clang::RecordType *RT,
967                                               const llvm::StringRef &TypeName,
968                                               unsigned Dim) {
969  assert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
970  assert((Dim > 1) && "Invalid dimension of matrix");
971
972  // Check whether the struct rs_matrix is in our expected form (but assume it's
973  // correct if we're not sure whether it's correct or not)
974  const clang::RecordDecl* RD = RT->getDecl();
975  RD = RD->getDefinition();
976  if (RD != NULL) {
977    clang::Diagnostic *Diags = Context->getDiagnostics();
978    const clang::SourceManager *SM = Context->getSourceManager();
979    // Find definition, perform further examination
980    if (RD->field_empty()) {
981      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
982                    Diags->getCustomDiagID(clang::Diagnostic::Error,
983                        "invalid matrix struct: must have 1 field for saving "
984                        "values: '%0'"))
985           << RD->getName();
986      return NULL;
987    }
988
989    clang::RecordDecl::field_iterator FIT = RD->field_begin();
990    const clang::FieldDecl *FD = *FIT;
991    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
992    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
993      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
994                    Diags->getCustomDiagID(clang::Diagnostic::Error,
995                        "invalid matrix struct: first field should be an "
996                        "array with constant size: '%0'"))
997           << RD->getName();
998      return NULL;
999    }
1000    const clang::ConstantArrayType *CAT =
1001      static_cast<const clang::ConstantArrayType *>(FT);
1002    const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
1003    if ((ElementType == NULL) ||
1004        (ElementType->getTypeClass() != clang::Type::Builtin) ||
1005        (static_cast<const clang::BuiltinType *>(ElementType)->getKind()
1006          != clang::BuiltinType::Float)) {
1007      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
1008                    Diags->getCustomDiagID(clang::Diagnostic::Error,
1009                        "invalid matrix struct: first field should be a "
1010                        "float array: '%0'"))
1011           << RD->getName();
1012      return NULL;
1013    }
1014
1015    if (CAT->getSize() != Dim * Dim) {
1016      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
1017                    Diags->getCustomDiagID(clang::Diagnostic::Error,
1018                        "invalid matrix struct: first field should be an "
1019                        "array with size %0: '%1'"))
1020           << Dim * Dim
1021           << RD->getName();
1022      return NULL;
1023    }
1024
1025    FIT++;
1026    if (FIT != RD->field_end()) {
1027      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
1028                    Diags->getCustomDiagID(clang::Diagnostic::Error,
1029                        "invalid matrix struct: must have exactly 1 field: "
1030                        "'%0'"))
1031           << RD->getName();
1032      return NULL;
1033    }
1034  }
1035
1036  return new RSExportMatrixType(Context, TypeName, Dim);
1037}
1038
1039const llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1040  // Construct LLVM type:
1041  // struct {
1042  //  float X[mDim * mDim];
1043  // }
1044
1045  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1046  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1047                                            mDim * mDim);
1048  return llvm::StructType::get(C, X, NULL);
1049}
1050
1051union RSType *RSExportMatrixType::convertToSpecType() const {
1052  llvm::OwningPtr<union RSType> ST(new union RSType);
1053  RS_TYPE_SET_CLASS(ST, RS_TC_Matrix);
1054  switch (getDim()) {
1055    case 2: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix2x2); break;
1056    case 3: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix3x3); break;
1057    case 4: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix4x4); break;
1058    default: assert(false && "Matrix type with unsupported dimension.");
1059  }
1060  return ST.take();
1061}
1062
1063bool RSExportMatrixType::equals(const RSExportable *E) const {
1064  CHECK_PARENT_EQUALITY(RSExportType, E);
1065  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1066}
1067
1068/************************* RSExportConstantArrayType *************************/
1069RSExportConstantArrayType
1070*RSExportConstantArrayType::Create(RSContext *Context,
1071                                   const clang::ConstantArrayType *CAT) {
1072  assert(CAT != NULL && CAT->getTypeClass() == clang::Type::ConstantArray);
1073
1074  assert((CAT->getSize().getActiveBits() < 32) && "array too large");
1075
1076  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1077  assert((Size > 0) && "Constant array should have size greater than 0");
1078
1079  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
1080  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
1081
1082  if (ElementET == NULL) {
1083    return NULL;
1084  }
1085
1086  return new RSExportConstantArrayType(Context,
1087                                       ElementET,
1088                                       Size);
1089}
1090
1091const llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1092  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
1093}
1094
1095union RSType *RSExportConstantArrayType::convertToSpecType() const {
1096  llvm::OwningPtr<union RSType> ST(new union RSType);
1097
1098  RS_TYPE_SET_CLASS(ST, RS_TC_ConstantArray);
1099  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_TYPE(
1100      ST, getElementType()->getSpecType());
1101  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_SIZE(ST, getSize());
1102
1103  if (RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(ST) != NULL)
1104    return ST.take();
1105  else
1106    return NULL;
1107}
1108
1109bool RSExportConstantArrayType::keep() {
1110  if (!RSExportType::keep())
1111    return false;
1112  const_cast<RSExportType*>(mElementType)->keep();
1113  return true;
1114}
1115
1116bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1117  CHECK_PARENT_EQUALITY(RSExportType, E);
1118  const RSExportConstantArrayType *RHS =
1119      static_cast<const RSExportConstantArrayType*>(E);
1120  return ((getSize() == RHS->getSize()) &&
1121          (getElementType()->equals(RHS->getElementType())));
1122}
1123
1124/**************************** RSExportRecordType ****************************/
1125RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1126                                               const clang::RecordType *RT,
1127                                               const llvm::StringRef &TypeName,
1128                                               bool mIsArtificial) {
1129  assert(RT != NULL && RT->getTypeClass() == clang::Type::Record);
1130
1131  const clang::RecordDecl *RD = RT->getDecl();
1132  assert(RD->isStruct());
1133
1134  RD = RD->getDefinition();
1135  if (RD == NULL) {
1136    assert(false && "struct is not defined in this module");
1137    return NULL;
1138  }
1139
1140  // Struct layout construct by clang. We rely on this for obtaining the
1141  // alloc size of a struct and offset of every field in that struct.
1142  const clang::ASTRecordLayout *RL =
1143      &Context->getASTContext().getASTRecordLayout(RD);
1144  assert((RL != NULL) && "Failed to retrieve the struct layout from Clang.");
1145
1146  RSExportRecordType *ERT =
1147      new RSExportRecordType(Context,
1148                             TypeName,
1149                             RD->hasAttr<clang::PackedAttr>(),
1150                             mIsArtificial,
1151                             (RL->getSize() >> 3));
1152  unsigned int Index = 0;
1153
1154  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1155           FE = RD->field_end();
1156       FI != FE;
1157       FI++, Index++) {
1158    clang::Diagnostic *Diags = Context->getDiagnostics();
1159    const clang::SourceManager *SM = Context->getSourceManager();
1160
1161    // FIXME: All fields should be primitive type
1162    assert((*FI)->getKind() == clang::Decl::Field);
1163    clang::FieldDecl *FD = *FI;
1164
1165    if (FD->isBitField()) {
1166      delete ERT;
1167      return NULL;
1168    }
1169
1170    // Type
1171    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1172
1173    if (ET != NULL) {
1174      ERT->mFields.push_back(
1175          new Field(ET, FD->getName(), ERT,
1176                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1177    } else {
1178      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
1179                    Diags->getCustomDiagID(clang::Diagnostic::Error,
1180                    "field type cannot be exported: '%0.%1'"))
1181          << RD->getName()
1182          << FD->getName();
1183      delete ERT;
1184      return NULL;
1185    }
1186  }
1187
1188  return ERT;
1189}
1190
1191const llvm::Type *RSExportRecordType::convertToLLVMType() const {
1192  // Create an opaque type since struct may reference itself recursively.
1193  llvm::PATypeHolder ResultHolder =
1194      llvm::OpaqueType::get(getRSContext()->getLLVMContext());
1195  setAbstractLLVMType(ResultHolder.get());
1196
1197  std::vector<const llvm::Type*> FieldTypes;
1198
1199  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1200       FI != FE;
1201       FI++) {
1202    const Field *F = *FI;
1203    const RSExportType *FET = F->getType();
1204
1205    FieldTypes.push_back(FET->getLLVMType());
1206  }
1207
1208  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1209                                               FieldTypes,
1210                                               mIsPacked);
1211  if (ST != NULL)
1212    static_cast<llvm::OpaqueType*>(ResultHolder.get())
1213        ->refineAbstractTypeTo(ST);
1214  else
1215    return NULL;
1216  return ResultHolder.get();
1217}
1218
1219union RSType *RSExportRecordType::convertToSpecType() const {
1220  unsigned NumFields = getFields().size();
1221  unsigned AllocSize = sizeof(union RSType) +
1222                       sizeof(struct RSRecordField) * NumFields;
1223  llvm::OwningPtr<union RSType> ST(
1224      reinterpret_cast<union RSType*>(operator new(AllocSize)));
1225
1226  ::memset(ST.get(), 0, AllocSize);
1227
1228  RS_TYPE_SET_CLASS(ST, RS_TC_Record);
1229  RS_RECORD_TYPE_SET_NAME(ST, getName().c_str());
1230  RS_RECORD_TYPE_SET_NUM_FIELDS(ST, NumFields);
1231
1232  setSpecTypeTemporarily(ST.get());
1233
1234  unsigned FieldIdx = 0;
1235  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1236       FI != FE;
1237       FI++, FieldIdx++) {
1238    const Field *F = *FI;
1239
1240    RS_RECORD_TYPE_SET_FIELD_NAME(ST, FieldIdx, F->getName().c_str());
1241    RS_RECORD_TYPE_SET_FIELD_TYPE(ST, FieldIdx, F->getType()->getSpecType());
1242
1243    enum RSDataKind DK = RS_DK_User;
1244    if ((F->getType()->getClass() == ExportClassPrimitive) ||
1245        (F->getType()->getClass() == ExportClassVector)) {
1246      const RSExportPrimitiveType *EPT =
1247        static_cast<const RSExportPrimitiveType*>(F->getType());
1248      // enum RSExportPrimitiveType::DataKind is synced with enum RSDataKind in
1249      // slang_rs_type_spec.h
1250      DK = static_cast<enum RSDataKind>(EPT->getKind());
1251    }
1252    RS_RECORD_TYPE_SET_FIELD_DATA_KIND(ST, FieldIdx, DK);
1253  }
1254
1255  // TODO(slang): Check whether all fields were created normally.
1256
1257  return ST.take();
1258}
1259
1260bool RSExportRecordType::keep() {
1261  if (!RSExportType::keep())
1262    return false;
1263  for (std::list<const Field*>::iterator I = mFields.begin(),
1264          E = mFields.end();
1265       I != E;
1266       I++) {
1267    const_cast<RSExportType*>((*I)->getType())->keep();
1268  }
1269  return true;
1270}
1271
1272bool RSExportRecordType::equals(const RSExportable *E) const {
1273  CHECK_PARENT_EQUALITY(RSExportType, E);
1274
1275  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1276
1277  if (ERT->getFields().size() != getFields().size())
1278    return false;
1279
1280  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1281
1282  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1283    if (!(*AI)->getType()->equals((*BI)->getType()))
1284      return false;
1285    AI++;
1286    BI++;
1287  }
1288
1289  return true;
1290}
1291
1292}  // namespace slang
1293