slang_rs_export_type.cpp revision 2ef9bc0cfbca2152d972c0975005f8c897c2a42c
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <list>
20#include <vector>
21
22#include "clang/AST/RecordLayout.h"
23
24#include "llvm/ADT/StringExtras.h"
25
26#include "llvm/DerivedTypes.h"
27
28#include "llvm/Target/TargetData.h"
29
30#include "llvm/Type.h"
31
32#include "slang_rs_context.h"
33#include "slang_rs_export_element.h"
34#include "slang_rs_type_spec.h"
35
// Evaluates ParentClass::equals(E) and makes the enclosing function return
// false when the parent-level comparison fails. The do/while(0) wrapper turns
// the expansion into a single statement. This makes it safe inside an unbraced
// if/else and consumes the trailing semicolon used at the call sites.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  do {                                        \
    if (!ParentClass::equals(E))              \
      return false;                           \
  } while (0)
39
40namespace slang {
41
42namespace {
43
// Forward declaration: ConstantArrayTypeExportableHelper() below recurses into
// TypeExportableHelper() to check the array element type, so the helper has to
// be declared before its definition further down. Keep this signature in sync
// with the definition.
static const clang::Type *TypeExportableHelper(
    const clang::Type *T,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    clang::Diagnostic *Diags,
    clang::SourceManager *SM,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord);
51
52static void ReportTypeError(clang::Diagnostic *Diags,
53                            const clang::SourceManager *SM,
54                            const clang::VarDecl *VD,
55                            const clang::RecordDecl *TopLevelRecord,
56                            const char *Message) {
57  if (!Diags || !SM) {
58    return;
59  }
60
61  // Attempt to use the type declaration first (if we have one).
62  // Fall back to the variable definition, if we are looking at something
63  // like an array declaration that can't be exported.
64  if (TopLevelRecord) {
65    Diags->Report(clang::FullSourceLoc(TopLevelRecord->getLocation(), *SM),
66                  Diags->getCustomDiagID(clang::Diagnostic::Error, Message))
67         << TopLevelRecord->getName();
68  } else if (VD) {
69    Diags->Report(clang::FullSourceLoc(VD->getLocation(), *SM),
70                  Diags->getCustomDiagID(clang::Diagnostic::Error, Message))
71         << VD->getName();
72  } else {
73    assert(false && "Variables should be validated before exporting");
74  }
75
76  return;
77}
78
79static const clang::Type *ConstantArrayTypeExportableHelper(
80    const clang::ConstantArrayType *CAT,
81    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
82    clang::Diagnostic *Diags,
83    clang::SourceManager *SM,
84    const clang::VarDecl *VD,
85    const clang::RecordDecl *TopLevelRecord) {
86  // Check element type
87  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
88  if (ElementType->isArrayType()) {
89    ReportTypeError(Diags, SM, VD, TopLevelRecord,
90        "multidimensional arrays cannot be exported: '%0'");
91    return NULL;
92  } else if (ElementType->isExtVectorType()) {
93    const clang::ExtVectorType *EVT =
94        static_cast<const clang::ExtVectorType*>(ElementType);
95    unsigned numElements = EVT->getNumElements();
96
97    const clang::Type *BaseElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
98    if (!RSExportPrimitiveType::IsPrimitiveType(BaseElementType)) {
99      ReportTypeError(Diags, SM, VD, TopLevelRecord,
100          "vectors of non-primitive types cannot be exported: '%0'");
101      return NULL;
102    }
103
104    if (numElements == 3 && CAT->getSize() != 1) {
105      ReportTypeError(Diags, SM, VD, TopLevelRecord,
106          "arrays of width 3 vector types cannot be exported: '%0'");
107      return NULL;
108    }
109  }
110
111  if (TypeExportableHelper(ElementType, SPS, Diags, SM, VD,
112                           TopLevelRecord) == NULL)
113    return NULL;
114  else
115    return CAT;
116}
117
// Recursive worker behind TypeExportable(). Returns the canonical form of T if
// it can be exported, or NULL otherwise. Several rejection paths report an
// error through Diags/SM; others reject silently.
//
// SPS: record types already visited during this walk. A type found in it is
//      returned immediately without being re-examined.
// VD: the variable being exported (may be NULL).
// TopLevelRecord: the outermost struct enclosing T, or NULL when T is not
//      nested in a struct.
// VD and TopLevelRecord are used to place diagnostics.
static const clang::Type *TypeExportableHelper(
    const clang::Type *T,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    clang::Diagnostic *Diags,
    clang::SourceManager *SM,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord) {
  // Normalize first
  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
    return NULL;

  // Already visited earlier in this walk.
  if (SPS.count(T))
    return T;

  switch (T->getTypeClass()) {
    case clang::Type::Builtin: {
      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);

      // Only the builtin kinds listed in RSClangBuiltinEnums.inc are
      // exportable; any other builtin is rejected without a diagnostic.
      switch (BT->getKind()) {
#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
        case builtin_type:
#include "RSClangBuiltinEnums.inc"
          return T;
        default: {
          return NULL;
        }
      }
    }
    case clang::Type::Record: {
      if (RSExportPrimitiveType::GetRSSpecificType(T) !=
          RSExportPrimitiveType::DataTypeUnknown)
        return T;  // RS-specific (object or matrix) type, no further checks

      // Check internal struct
      if (T->isUnionType()) {
        ReportTypeError(Diags, SM, NULL, T->getAsUnionType()->getDecl(),
            "unions cannot be exported: '%0'");
        return NULL;
      } else if (!T->isStructureType()) {
        assert(false && "Unknown type cannot be exported");
        return NULL;
      }

      // Work on the struct's definition; a struct with no definition visible
      // here is rejected.
      clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
      if (RD != NULL) {
        RD = RD->getDefinition();
        if (RD == NULL) {
          ReportTypeError(Diags, SM, NULL, T->getAsStructureType()->getDecl(),
              "struct is not defined in this module");
          return NULL;
        }
      }

      // The first struct reached on this path anchors diagnostics for
      // anything nested inside it.
      if (!TopLevelRecord) {
        TopLevelRecord = RD;
      }
      if (RD->getName().empty()) {
        ReportTypeError(Diags, SM, NULL, RD,
            "anonymous structures cannot be exported");
        return NULL;
      }

      // Fast check
      // NOTE(review): this rejection path emits no diagnostic.
      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
        return NULL;

      // Insert myself into checking set.
      // This happens before the fields are visited, so any later occurrence
      // of this type in the same walk returns early at the SPS.count() test.
      SPS.insert(T);

      // Check all element
      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
               FE = RD->field_end();
           FI != FE;
           FI++) {
        const clang::FieldDecl *FD = *FI;
        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
        FT = GET_CANONICAL_TYPE(FT);

        if (!TypeExportableHelper(FT, SPS, Diags, SM, VD, TopLevelRecord)) {
          return NULL;
        }

        // We don't support bit fields yet
        //
        // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
        if (FD->isBitField()) {
          if (Diags && SM) {
            Diags->Report(clang::FullSourceLoc(FD->getLocation(), *SM),
                          Diags->getCustomDiagID(clang::Diagnostic::Error,
                          "bit fields are not able to be exported: '%0.%1'"))
                << RD->getName()
                << FD->getName();
          }
          return NULL;
        }
      }

      return T;
    }
    case clang::Type::Pointer: {
      // Pointers are only allowed outside of structs.
      if (TopLevelRecord) {
        ReportTypeError(Diags, SM, NULL, TopLevelRecord,
            "structures containing pointers cannot be exported: '%0'");
        return NULL;
      }

      const clang::PointerType *PT = UNSAFE_CAST_TYPE(clang::PointerType, T);
      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);

      // Pointer-to-pointer is accepted as is; the inner pointee is not
      // checked any further.
      if (PointeeType->getTypeClass() == clang::Type::Pointer)
        return T;
      // We don't support pointer with array-type pointee or unsupported pointee
      // type
      if (PointeeType->isArrayType() ||
          (TypeExportableHelper(PointeeType, SPS, Diags, SM, VD,
                                TopLevelRecord) == NULL))
        return NULL;
      else
        return T;
    }
    case clang::Type::ExtVector: {
      const clang::ExtVectorType *EVT =
          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
      // Only vector with size 2, 3 and 4 are supported.
      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
        return NULL;

      // Check base element type
      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);

      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
          (TypeExportableHelper(ElementType, SPS, Diags, SM, VD,
                                TopLevelRecord) == NULL))
        return NULL;
      else
        return T;
    }
    case clang::Type::ConstantArray: {
      const clang::ConstantArrayType *CAT =
          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);

      return ConstantArrayTypeExportableHelper(CAT, SPS, Diags, SM, VD,
                                               TopLevelRecord);
    }
    default: {
      // Every other type class is not exportable.
      return NULL;
    }
  }
}
267
268// Return the type that can be used to create RSExportType, will always return
269// the canonical type
270// If the Type T is not exportable, this function returns NULL. Diags and SM
271// are used to generate proper Clang diagnostic messages when a
272// non-exportable type is detected. TopLevelRecord is used to capture the
273// highest struct (in the case of a nested hierarchy) for detecting other
274// types that cannot be exported (mostly pointers within a struct).
275static const clang::Type *TypeExportable(const clang::Type *T,
276                                         clang::Diagnostic *Diags,
277                                         clang::SourceManager *SM,
278                                         const clang::VarDecl *VD) {
279  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
280      llvm::SmallPtrSet<const clang::Type*, 8>();
281
282  return TypeExportableHelper(T, SPS, Diags, SM, VD, NULL);
283}
284
285}  // namespace
286
287/****************************** RSExportType ******************************/
288bool RSExportType::NormalizeType(const clang::Type *&T,
289                                 llvm::StringRef &TypeName,
290                                 clang::Diagnostic *Diags,
291                                 clang::SourceManager *SM,
292                                 const clang::VarDecl *VD) {
293  if ((T = TypeExportable(T, Diags, SM, VD)) == NULL) {
294    return false;
295  }
296  // Get type name
297  TypeName = RSExportType::GetTypeName(T);
298  if (TypeName.empty()) {
299    if (Diags && SM) {
300      if (VD) {
301        Diags->Report(clang::FullSourceLoc(VD->getLocation(), *SM),
302                      Diags->getCustomDiagID(clang::Diagnostic::Error,
303                                             "anonymous types cannot "
304                                             "be exported"));
305      } else {
306        Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
307                                             "anonymous types cannot "
308                                             "be exported"));
309      }
310    }
311    return false;
312  }
313
314  return true;
315}
316
317const clang::Type
318*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
319  if (DD) {
320    clang::QualType T;
321    if (DD->getTypeSourceInfo())
322      T = DD->getTypeSourceInfo()->getType();
323    else
324      T = DD->getType();
325
326    if (T.isNull())
327      return NULL;
328    else
329      return T.getTypePtr();
330  }
331  return NULL;
332}
333
334llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
335  T = GET_CANONICAL_TYPE(T);
336  if (T == NULL)
337    return llvm::StringRef();
338
339  switch (T->getTypeClass()) {
340    case clang::Type::Builtin: {
341      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
342
343      switch (BT->getKind()) {
344#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
345        case builtin_type:                                    \
346          return cname;                                       \
347        break;
348#include "RSClangBuiltinEnums.inc"
349        default: {
350          assert(false && "Unknown data type of the builtin");
351          break;
352        }
353      }
354      break;
355    }
356    case clang::Type::Record: {
357      clang::RecordDecl *RD;
358      if (T->isStructureType()) {
359        RD = T->getAsStructureType()->getDecl();
360      } else {
361        break;
362      }
363
364      llvm::StringRef Name = RD->getName();
365      if (Name.empty()) {
366          if (RD->getTypedefForAnonDecl() != NULL)
367            Name = RD->getTypedefForAnonDecl()->getName();
368
369          if (Name.empty())
370            // Try to find a name from redeclaration (i.e. typedef)
371            for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
372                     RE = RD->redecls_end();
373                 RI != RE;
374                 RI++) {
375              assert(*RI != NULL && "cannot be NULL object");
376
377              Name = (*RI)->getName();
378              if (!Name.empty())
379                break;
380            }
381      }
382      return Name;
383    }
384    case clang::Type::Pointer: {
385      // "*" plus pointee name
386      const clang::Type *PT = GET_POINTEE_TYPE(T);
387      llvm::StringRef PointeeName;
388      if (NormalizeType(PT, PointeeName, NULL, NULL, NULL)) {
389        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
390        Name[0] = '*';
391        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
392        Name[PointeeName.size() + 1] = '\0';
393        return Name;
394      }
395      break;
396    }
397    case clang::Type::ExtVector: {
398      const clang::ExtVectorType *EVT =
399          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
400      return RSExportVectorType::GetTypeName(EVT);
401      break;
402    }
403    case clang::Type::ConstantArray : {
404      // Construct name for a constant array is too complicated.
405      return DUMMY_TYPE_NAME_FOR_RS_CONSTANT_ARRAY_TYPE;
406    }
407    default: {
408      break;
409    }
410  }
411
412  return llvm::StringRef();
413}
414
415
416RSExportType *RSExportType::Create(RSContext *Context,
417                                   const clang::Type *T,
418                                   const llvm::StringRef &TypeName) {
419  // Lookup the context to see whether the type was processed before.
420  // Newly created RSExportType will insert into context
421  // in RSExportType::RSExportType()
422  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
423
424  if (ETI != Context->export_types_end())
425    return ETI->second;
426
427  RSExportType *ET = NULL;
428  switch (T->getTypeClass()) {
429    case clang::Type::Record: {
430      RSExportPrimitiveType::DataType dt =
431          RSExportPrimitiveType::GetRSSpecificType(TypeName);
432      switch (dt) {
433        case RSExportPrimitiveType::DataTypeUnknown: {
434          // User-defined types
435          ET = RSExportRecordType::Create(Context,
436                                          T->getAsStructureType(),
437                                          TypeName);
438          break;
439        }
440        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
441          // 2 x 2 Matrix type
442          ET = RSExportMatrixType::Create(Context,
443                                          T->getAsStructureType(),
444                                          TypeName,
445                                          2);
446          break;
447        }
448        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
449          // 3 x 3 Matrix type
450          ET = RSExportMatrixType::Create(Context,
451                                          T->getAsStructureType(),
452                                          TypeName,
453                                          3);
454          break;
455        }
456        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
457          // 4 x 4 Matrix type
458          ET = RSExportMatrixType::Create(Context,
459                                          T->getAsStructureType(),
460                                          TypeName,
461                                          4);
462          break;
463        }
464        default: {
465          // Others are primitive types
466          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
467          break;
468        }
469      }
470      break;
471    }
472    case clang::Type::Builtin: {
473      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
474      break;
475    }
476    case clang::Type::Pointer: {
477      ET = RSExportPointerType::Create(Context,
478                                       UNSAFE_CAST_TYPE(clang::PointerType, T),
479                                       TypeName);
480      // FIXME: free the name (allocated in RSExportType::GetTypeName)
481      delete [] TypeName.data();
482      break;
483    }
484    case clang::Type::ExtVector: {
485      ET = RSExportVectorType::Create(Context,
486                                      UNSAFE_CAST_TYPE(clang::ExtVectorType, T),
487                                      TypeName);
488      break;
489    }
490    case clang::Type::ConstantArray: {
491      ET = RSExportConstantArrayType::Create(
492              Context,
493              UNSAFE_CAST_TYPE(clang::ConstantArrayType, T));
494      break;
495    }
496    default: {
497      clang::Diagnostic *Diags = Context->getDiagnostics();
498      Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
499                        "unknown type cannot be exported: '%0'"))
500          << T->getTypeClassName();
501      break;
502    }
503  }
504
505  return ET;
506}
507
508RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
509  llvm::StringRef TypeName;
510  if (NormalizeType(T, TypeName, NULL, NULL, NULL))
511    return Create(Context, T, TypeName);
512  else
513    return NULL;
514}
515
516RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
517                                           const clang::VarDecl *VD) {
518  return RSExportType::Create(Context, GetTypeOfDecl(VD));
519}
520
521size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
522  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
523      ET->getLLVMType());
524}
525
526size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
527  if (ET->getClass() == RSExportType::ExportClassRecord)
528    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
529  else
530    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
531        ET->getLLVMType());
532}
533
534RSExportType::RSExportType(RSContext *Context,
535                           ExportClass Class,
536                           const llvm::StringRef &Name)
537    : RSExportable(Context, RSExportable::EX_TYPE),
538      mClass(Class),
539      // Make a copy on Name since memory stored @Name is either allocated in
540      // ASTContext or allocated in GetTypeName which will be destroyed later.
541      mName(Name.data(), Name.size()),
542      mLLVMType(NULL),
543      mSpecType(NULL) {
544  // Don't cache the type whose name start with '<'. Those type failed to
545  // get their name since constructing their name in GetTypeName() requiring
546  // complicated work.
547  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
548    // TODO(zonr): Need to check whether the insertion is successful or not.
549    Context->insertExportType(llvm::StringRef(Name), this);
550  return;
551}
552
553bool RSExportType::keep() {
554  if (!RSExportable::keep())
555    return false;
556  // Invalidate converted LLVM type.
557  mLLVMType = NULL;
558  return true;
559}
560
561bool RSExportType::equals(const RSExportable *E) const {
562  CHECK_PARENT_EQUALITY(RSExportable, E);
563  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
564}
565
// Frees mSpecType. It is initialized to NULL in the constructor, and deleting
// NULL is a no-op. Presumably it is filled lazily from convertToSpecType()
// elsewhere — TODO confirm.
RSExportType::~RSExportType() {
  delete mSpecType;
}
569
570/************************** RSExportPrimitiveType **************************/
// Map from RS-specific type names (RS matrix and RS object types) to their
// DataType. It is populated on first use in GetRSSpecificType(StringRef).
llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
RSExportPrimitiveType::RSSpecificTypeMap;

// LLVM type shared by all RS object types; created on first use in
// RSExportPrimitiveType::convertToLLVMType().
llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
575
576bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
577  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
578    return true;
579  else
580    return false;
581}
582
583RSExportPrimitiveType::DataType
584RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
585  if (TypeName.empty())
586    return DataTypeUnknown;
587
588  if (RSSpecificTypeMap->empty()) {
589#define ENUM_RS_MATRIX_TYPE(type, cname, dim)                       \
590    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
591#include "RSMatrixTypeEnums.inc"
592#define ENUM_RS_OBJECT_TYPE(type, cname)                            \
593    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
594#include "RSObjectTypeEnums.inc"
595  }
596
597  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
598  if (I == RSSpecificTypeMap->end())
599    return DataTypeUnknown;
600  else
601    return I->getValue();
602}
603
604RSExportPrimitiveType::DataType
605RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
606  T = GET_CANONICAL_TYPE(T);
607  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
608    return DataTypeUnknown;
609
610  return GetRSSpecificType( RSExportType::GetTypeName(T) );
611}
612
613bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
614  return ((DT >= FirstRSMatrixType) && (DT <= LastRSMatrixType));
615}
616
617bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
618  return ((DT >= FirstRSObjectType) && (DT <= LastRSObjectType));
619}
620
// Bit width of each DataType, generated from RSDataTypeEnums.inc and indexed
// by the enum value (see GetSizeInBits()). The trailing 0 fills the
// DataTypeMax slot.
const size_t RSExportPrimitiveType::SizeOfDataTypeInBits[] = {
#define ENUM_RS_DATA_TYPE(type, cname, bits)  \
  bits,
#include "RSDataTypeEnums.inc"
  0   // DataTypeMax
};
627
628size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
629  assert(((EPT->getType() > DataTypeUnknown) &&
630          (EPT->getType() < DataTypeMax)) &&
631         "RSExportPrimitiveType::GetSizeInBits : unknown data type");
632  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
633}
634
635RSExportPrimitiveType::DataType
636RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
637  if (T == NULL)
638    return DataTypeUnknown;
639
640  switch (T->getTypeClass()) {
641    case clang::Type::Builtin: {
642      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
643      switch (BT->getKind()) {
644#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
645        case builtin_type: {                                  \
646          return DataType ## type;                            \
647        }
648#include "RSClangBuiltinEnums.inc"
649        // The size of type WChar depend on platform so we abandon the support
650        // to them.
651        default: {
652          clang::Diagnostic *Diags = Context->getDiagnostics();
653          Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
654                            "built-in type cannot be exported: '%0'"))
655              << T->getTypeClassName();
656          break;
657        }
658      }
659      break;
660    }
661    case clang::Type::Record: {
662      // must be RS object type
663      return RSExportPrimitiveType::GetRSSpecificType(T);
664    }
665    default: {
666      clang::Diagnostic *Diags = Context->getDiagnostics();
667      Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
668                        "primitive type cannot be exported: '%0'"))
669          << T->getTypeClassName();
670      break;
671    }
672  }
673
674  return DataTypeUnknown;
675}
676
677RSExportPrimitiveType
678*RSExportPrimitiveType::Create(RSContext *Context,
679                               const clang::Type *T,
680                               const llvm::StringRef &TypeName,
681                               DataKind DK,
682                               bool Normalized) {
683  DataType DT = GetDataType(Context, T);
684
685  if ((DT == DataTypeUnknown) || TypeName.empty())
686    return NULL;
687  else
688    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
689                                     DT, DK, Normalized);
690}
691
692RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
693                                                     const clang::Type *T,
694                                                     DataKind DK) {
695  llvm::StringRef TypeName;
696  if (RSExportType::NormalizeType(T, TypeName, NULL, NULL, NULL) &&
697      IsPrimitiveType(T)) {
698    return Create(Context, T, TypeName, DK);
699  } else {
700    return NULL;
701  }
702}
703
704const llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
705  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
706
707  if (isRSObjectType()) {
708    // struct {
709    //   int *p;
710    // } __attribute__((packed, aligned(pointer_size)))
711    //
712    // which is
713    //
714    // <{ [1 x i32] }> in LLVM
715    //
716    if (RSObjectLLVMType == NULL) {
717      std::vector<const llvm::Type *> Elements;
718      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
719      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
720    }
721    return RSObjectLLVMType;
722  }
723
724  switch (mType) {
725    case DataTypeFloat32: {
726      return llvm::Type::getFloatTy(C);
727      break;
728    }
729    case DataTypeFloat64: {
730      return llvm::Type::getDoubleTy(C);
731      break;
732    }
733    case DataTypeBoolean: {
734      return llvm::Type::getInt1Ty(C);
735      break;
736    }
737    case DataTypeSigned8:
738    case DataTypeUnsigned8: {
739      return llvm::Type::getInt8Ty(C);
740      break;
741    }
742    case DataTypeSigned16:
743    case DataTypeUnsigned16:
744    case DataTypeUnsigned565:
745    case DataTypeUnsigned5551:
746    case DataTypeUnsigned4444: {
747      return llvm::Type::getInt16Ty(C);
748      break;
749    }
750    case DataTypeSigned32:
751    case DataTypeUnsigned32: {
752      return llvm::Type::getInt32Ty(C);
753      break;
754    }
755    case DataTypeSigned64:
756    case DataTypeUnsigned64: {
757      return llvm::Type::getInt64Ty(C);
758      break;
759    }
760    default: {
761      assert(false && "Unknown data type");
762    }
763  }
764
765  return NULL;
766}
767
768union RSType *RSExportPrimitiveType::convertToSpecType() const {
769  llvm::OwningPtr<union RSType> ST(new union RSType);
770  RS_TYPE_SET_CLASS(ST, RS_TC_Primitive);
771  // enum RSExportPrimitiveType::DataType is synced with enum RSDataType in
772  // slang_rs_type_spec.h
773  RS_PRIMITIVE_TYPE_SET_DATA_TYPE(ST, getType());
774  return ST.take();
775}
776
777bool RSExportPrimitiveType::equals(const RSExportable *E) const {
778  CHECK_PARENT_EQUALITY(RSExportType, E);
779  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
780}
781
782/**************************** RSExportPointerType ****************************/
783
784RSExportPointerType
785*RSExportPointerType::Create(RSContext *Context,
786                             const clang::PointerType *PT,
787                             const llvm::StringRef &TypeName) {
788  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
789  const RSExportType *PointeeET;
790
791  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
792    PointeeET = RSExportType::Create(Context, PointeeType);
793  } else {
794    // Double or higher dimension of pointer, export as int*
795    PointeeET = RSExportPrimitiveType::Create(Context,
796                    Context->getASTContext().IntTy.getTypePtr());
797  }
798
799  if (PointeeET == NULL) {
800    // Error diagnostic is emitted for corresponding pointee type
801    return NULL;
802  }
803
804  return new RSExportPointerType(Context, TypeName, PointeeET);
805}
806
807const llvm::Type *RSExportPointerType::convertToLLVMType() const {
808  const llvm::Type *PointeeType = mPointeeType->getLLVMType();
809  return llvm::PointerType::getUnqual(PointeeType);
810}
811
812union RSType *RSExportPointerType::convertToSpecType() const {
813  llvm::OwningPtr<union RSType> ST(new union RSType);
814
815  RS_TYPE_SET_CLASS(ST, RS_TC_Pointer);
816  RS_POINTER_TYPE_SET_POINTEE_TYPE(ST, getPointeeType()->getSpecType());
817
818  if (RS_POINTER_TYPE_GET_POINTEE_TYPE(ST) != NULL)
819    return ST.take();
820  else
821    return NULL;
822}
823
824bool RSExportPointerType::keep() {
825  if (!RSExportType::keep())
826    return false;
827  const_cast<RSExportType*>(mPointeeType)->keep();
828  return true;
829}
830
831bool RSExportPointerType::equals(const RSExportable *E) const {
832  CHECK_PARENT_EQUALITY(RSExportType, E);
833  return (static_cast<const RSExportPointerType*>(E)
834              ->getPointeeType()->equals(getPointeeType()));
835}
836
837/***************************** RSExportVectorType *****************************/
838llvm::StringRef
839RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
840  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
841
842  if ((ElementType->getTypeClass() != clang::Type::Builtin))
843    return llvm::StringRef();
844
845  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType,
846                                                  ElementType);
847  if ((EVT->getNumElements() < 1) ||
848      (EVT->getNumElements() > 4))
849    return llvm::StringRef();
850
851  switch (BT->getKind()) {
852    // Compiler is smart enough to optimize following *big if branches* since
853    // they all become "constant comparison" after macro expansion
854#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
855    case builtin_type: {                                      \
856      const char *Name[] = { cname"2", cname"3", cname"4" };  \
857      return Name[EVT->getNumElements() - 2];                 \
858      break;                                                  \
859    }
860#include "RSClangBuiltinEnums.inc"
861    default: {
862      return llvm::StringRef();
863    }
864  }
865}
866
867RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
868                                               const clang::ExtVectorType *EVT,
869                                               const llvm::StringRef &TypeName,
870                                               DataKind DK,
871                                               bool Normalized) {
872  assert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
873
874  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
875  RSExportPrimitiveType::DataType DT =
876      RSExportPrimitiveType::GetDataType(Context, ElementType);
877
878  if (DT != RSExportPrimitiveType::DataTypeUnknown)
879    return new RSExportVectorType(Context,
880                                  TypeName,
881                                  DT,
882                                  DK,
883                                  Normalized,
884                                  EVT->getNumElements());
885  else
886    return NULL;
887}
888
889const llvm::Type *RSExportVectorType::convertToLLVMType() const {
890  const llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
891  return llvm::VectorType::get(ElementType, getNumElement());
892}
893
894union RSType *RSExportVectorType::convertToSpecType() const {
895  llvm::OwningPtr<union RSType> ST(new union RSType);
896
897  RS_TYPE_SET_CLASS(ST, RS_TC_Vector);
898  RS_VECTOR_TYPE_SET_ELEMENT_TYPE(ST, getType());
899  RS_VECTOR_TYPE_SET_VECTOR_SIZE(ST, getNumElement());
900
901  return ST.take();
902}
903
904bool RSExportVectorType::equals(const RSExportable *E) const {
905  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
906  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
907              == getNumElement());
908}
909
910/***************************** RSExportMatrixType *****************************/
911RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
912                                               const clang::RecordType *RT,
913                                               const llvm::StringRef &TypeName,
914                                               unsigned Dim) {
915  assert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
916  assert((Dim > 1) && "Invalid dimension of matrix");
917
918  // Check whether the struct rs_matrix is in our expected form (but assume it's
919  // correct if we're not sure whether it's correct or not)
920  const clang::RecordDecl* RD = RT->getDecl();
921  RD = RD->getDefinition();
922  if (RD != NULL) {
923    clang::Diagnostic *Diags = Context->getDiagnostics();
924    const clang::SourceManager *SM = Context->getSourceManager();
925    // Find definition, perform further examination
926    if (RD->field_empty()) {
927      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
928                    Diags->getCustomDiagID(clang::Diagnostic::Error,
929                        "invalid matrix struct: must have 1 field for saving "
930                        "values: '%0'"))
931           << RD->getName();
932      return NULL;
933    }
934
935    clang::RecordDecl::field_iterator FIT = RD->field_begin();
936    const clang::FieldDecl *FD = *FIT;
937    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
938    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
939      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
940                    Diags->getCustomDiagID(clang::Diagnostic::Error,
941                        "invalid matrix struct: first field should be an "
942                        "array with constant size: '%0'"))
943           << RD->getName();
944      return NULL;
945    }
946    const clang::ConstantArrayType *CAT =
947      static_cast<const clang::ConstantArrayType *>(FT);
948    const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
949    if ((ElementType == NULL) ||
950        (ElementType->getTypeClass() != clang::Type::Builtin) ||
951        (static_cast<const clang::BuiltinType *>(ElementType)->getKind()
952          != clang::BuiltinType::Float)) {
953      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
954                    Diags->getCustomDiagID(clang::Diagnostic::Error,
955                        "invalid matrix struct: first field should be a "
956                        "float array: '%0'"))
957           << RD->getName();
958      return NULL;
959    }
960
961    if (CAT->getSize() != Dim * Dim) {
962      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
963                    Diags->getCustomDiagID(clang::Diagnostic::Error,
964                        "invalid matrix struct: first field should be an "
965                        "array with size %0: '%1'"))
966           << Dim * Dim
967           << RD->getName();
968      return NULL;
969    }
970
971    FIT++;
972    if (FIT != RD->field_end()) {
973      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
974                    Diags->getCustomDiagID(clang::Diagnostic::Error,
975                        "invalid matrix struct: must have exactly 1 field: "
976                        "'%0'"))
977           << RD->getName();
978      return NULL;
979    }
980  }
981
982  return new RSExportMatrixType(Context, TypeName, Dim);
983}
984
985const llvm::Type *RSExportMatrixType::convertToLLVMType() const {
986  // Construct LLVM type:
987  // struct {
988  //  float X[mDim * mDim];
989  // }
990
991  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
992  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
993                                            mDim * mDim);
994  return llvm::StructType::get(C, X, NULL);
995}
996
997union RSType *RSExportMatrixType::convertToSpecType() const {
998  llvm::OwningPtr<union RSType> ST(new union RSType);
999  RS_TYPE_SET_CLASS(ST, RS_TC_Matrix);
1000  switch (getDim()) {
1001    case 2: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix2x2); break;
1002    case 3: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix3x3); break;
1003    case 4: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix4x4); break;
1004    default: assert(false && "Matrix type with unsupported dimension.");
1005  }
1006  return ST.take();
1007}
1008
1009bool RSExportMatrixType::equals(const RSExportable *E) const {
1010  CHECK_PARENT_EQUALITY(RSExportType, E);
1011  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1012}
1013
1014/************************* RSExportConstantArrayType *************************/
1015RSExportConstantArrayType
1016*RSExportConstantArrayType::Create(RSContext *Context,
1017                                   const clang::ConstantArrayType *CAT) {
1018  assert(CAT != NULL && CAT->getTypeClass() == clang::Type::ConstantArray);
1019
1020  assert((CAT->getSize().getActiveBits() < 32) && "array too large");
1021
1022  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1023  assert((Size > 0) && "Constant array should have size greater than 0");
1024
1025  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
1026  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
1027
1028  if (ElementET == NULL) {
1029    return NULL;
1030  }
1031
1032  return new RSExportConstantArrayType(Context,
1033                                       ElementET,
1034                                       Size);
1035}
1036
1037const llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1038  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
1039}
1040
1041union RSType *RSExportConstantArrayType::convertToSpecType() const {
1042  llvm::OwningPtr<union RSType> ST(new union RSType);
1043
1044  RS_TYPE_SET_CLASS(ST, RS_TC_ConstantArray);
1045  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_TYPE(
1046      ST, getElementType()->getSpecType());
1047  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_SIZE(ST, getSize());
1048
1049  if (RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(ST) != NULL)
1050    return ST.take();
1051  else
1052    return NULL;
1053}
1054
1055bool RSExportConstantArrayType::keep() {
1056  if (!RSExportType::keep())
1057    return false;
1058  const_cast<RSExportType*>(mElementType)->keep();
1059  return true;
1060}
1061
1062bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1063  CHECK_PARENT_EQUALITY(RSExportType, E);
1064  const RSExportConstantArrayType *RHS =
1065      static_cast<const RSExportConstantArrayType*>(E);
1066  return ((getSize() == RHS->getSize()) &&
1067          (getElementType()->equals(RHS->getElementType())));
1068}
1069
1070/**************************** RSExportRecordType ****************************/
1071RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1072                                               const clang::RecordType *RT,
1073                                               const llvm::StringRef &TypeName,
1074                                               bool mIsArtificial) {
1075  assert(RT != NULL && RT->getTypeClass() == clang::Type::Record);
1076
1077  const clang::RecordDecl *RD = RT->getDecl();
1078  assert(RD->isStruct());
1079
1080  RD = RD->getDefinition();
1081  if (RD == NULL) {
1082    assert(false && "struct is not defined in this module");
1083    return NULL;
1084  }
1085
1086  // Struct layout construct by clang. We rely on this for obtaining the
1087  // alloc size of a struct and offset of every field in that struct.
1088  const clang::ASTRecordLayout *RL =
1089      &Context->getASTContext().getASTRecordLayout(RD);
1090  assert((RL != NULL) && "Failed to retrieve the struct layout from Clang.");
1091
1092  RSExportRecordType *ERT =
1093      new RSExportRecordType(Context,
1094                             TypeName,
1095                             RD->hasAttr<clang::PackedAttr>(),
1096                             mIsArtificial,
1097                             (RL->getSize() >> 3));
1098  unsigned int Index = 0;
1099
1100  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1101           FE = RD->field_end();
1102       FI != FE;
1103       FI++, Index++) {
1104    clang::Diagnostic *Diags = Context->getDiagnostics();
1105    const clang::SourceManager *SM = Context->getSourceManager();
1106
1107    // FIXME: All fields should be primitive type
1108    assert((*FI)->getKind() == clang::Decl::Field);
1109    clang::FieldDecl *FD = *FI;
1110
1111    if (FD->isBitField()) {
1112      delete ERT;
1113      return NULL;
1114    }
1115
1116    // Type
1117    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1118
1119    if (ET != NULL) {
1120      ERT->mFields.push_back(
1121          new Field(ET, FD->getName(), ERT,
1122                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1123    } else {
1124      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
1125                    Diags->getCustomDiagID(clang::Diagnostic::Error,
1126                    "field type cannot be exported: '%0.%1'"))
1127          << RD->getName()
1128          << FD->getName();
1129      delete ERT;
1130      return NULL;
1131    }
1132  }
1133
1134  return ERT;
1135}
1136
// Lowers this record to an LLVM struct type.
//
// An opaque placeholder type is created and registered with
// setAbstractLLVMType() *before* any field type is lowered. Per the comment
// below, this is because the struct may reference itself recursively;
// presumably getLLVMType() on such a field finds the placeholder (verify
// against RSExportType::getLLVMType). Once all field types are collected,
// the placeholder is refined to the real struct. Fields appear in
// fields_begin()/fields_end() order, and mIsPacked is passed as the "packed"
// flag. Returns NULL if llvm::StructType::get() returns NULL.
const llvm::Type *RSExportRecordType::convertToLLVMType() const {
  // Create an opaque type since struct may reference itself recursively.
  llvm::PATypeHolder ResultHolder =
      llvm::OpaqueType::get(getRSContext()->getLLVMContext());
  setAbstractLLVMType(ResultHolder.get());

  std::vector<const llvm::Type*> FieldTypes;

  // Lower each field's type, in declaration order.
  for (const_field_iterator FI = fields_begin(), FE = fields_end();
       FI != FE;
       FI++) {
    const Field *F = *FI;
    const RSExportType *FET = F->getType();

    FieldTypes.push_back(FET->getLLVMType());
  }

  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
                                               FieldTypes,
                                               mIsPacked);
  // Replace the opaque placeholder with the concrete struct. ResultHolder
  // tracks the refinement, so ResultHolder.get() yields the final type.
  if (ST != NULL)
    static_cast<llvm::OpaqueType*>(ResultHolder.get())
        ->refineAbstractTypeTo(ST);
  else
    return NULL;
  return ResultHolder.get();
}
1164
// Builds the RSType spec for this record.
//
// One zero-filled allocation holds the RSType header followed by NumFields
// RSRecordField entries. It is filled with the record name, the field count,
// and each field's name, spec type and data kind. The caller owns the
// returned object.
//
// NOTE(review): the buffer comes from raw `operator new(AllocSize)`, but
// OwningPtr (and presumably the eventual owner) releases it with
// `delete (union RSType*)`. That is not a matching new/delete pair, and with
// C++14 sized deallocation the size passed back would be sizeof(RSType)
// rather than AllocSize. Confirm how callers free the result.
// NOTE(review): the name setters store getName().c_str() pointers. This
// assumes getName() returns a reference to storage that outlives the spec;
// verify.
union RSType *RSExportRecordType::convertToSpecType() const {
  unsigned NumFields = getFields().size();
  unsigned AllocSize = sizeof(union RSType) +
                       sizeof(struct RSRecordField) * NumFields;
  llvm::OwningPtr<union RSType> ST(
      reinterpret_cast<union RSType*>(operator new(AllocSize)));

  ::memset(ST.get(), 0, AllocSize);

  RS_TYPE_SET_CLASS(ST, RS_TC_Record);
  RS_RECORD_TYPE_SET_NAME(ST, getName().c_str());
  RS_RECORD_TYPE_SET_NUM_FIELDS(ST, NumFields);

  // Publish the partially built spec before visiting the fields. This is
  // presumably so a field type that refers back to this record can pick it up
  // from getSpecType() instead of recursing (verify against getSpecType).
  setSpecTypeTemporarily(ST.get());

  unsigned FieldIdx = 0;
  for (const_field_iterator FI = fields_begin(), FE = fields_end();
       FI != FE;
       FI++, FieldIdx++) {
    const Field *F = *FI;

    RS_RECORD_TYPE_SET_FIELD_NAME(ST, FieldIdx, F->getName().c_str());
    RS_RECORD_TYPE_SET_FIELD_TYPE(ST, FieldIdx, F->getType()->getSpecType());

    // Primitive and vector fields carry their own data kind; every other
    // field class is tagged RS_DK_User.
    enum RSDataKind DK = RS_DK_User;
    if ((F->getType()->getClass() == ExportClassPrimitive) ||
        (F->getType()->getClass() == ExportClassVector)) {
      const RSExportPrimitiveType *EPT =
        static_cast<const RSExportPrimitiveType*>(F->getType());
      // enum RSExportPrimitiveType::DataKind is synced with enum RSDataKind in
      // slang_rs_type_spec.h
      DK = static_cast<enum RSDataKind>(EPT->getKind());
    }
    RS_RECORD_TYPE_SET_FIELD_DATA_KIND(ST, FieldIdx, DK);
  }

  // TODO(slang): Check whether all fields were created normally.

  return ST.take();
}
1205
1206bool RSExportRecordType::keep() {
1207  if (!RSExportType::keep())
1208    return false;
1209  for (std::list<const Field*>::iterator I = mFields.begin(),
1210          E = mFields.end();
1211       I != E;
1212       I++) {
1213    const_cast<RSExportType*>((*I)->getType())->keep();
1214  }
1215  return true;
1216}
1217
1218bool RSExportRecordType::equals(const RSExportable *E) const {
1219  CHECK_PARENT_EQUALITY(RSExportType, E);
1220
1221  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1222
1223  if (ERT->getFields().size() != getFields().size())
1224    return false;
1225
1226  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1227
1228  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1229    if (!(*AI)->getType()->equals((*BI)->getType()))
1230      return false;
1231    AI++;
1232    BI++;
1233  }
1234
1235  return true;
1236}
1237
1238}  // namespace slang
1239