slang_rs_export_type.cpp revision fe41c8da501e5a742f1aa6fade2592b4b80bca3b
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <list>
20#include <vector>
21
22#include "clang/AST/ASTContext.h"
23#include "clang/AST/Attr.h"
24#include "clang/AST/RecordLayout.h"
25
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/DerivedTypes.h"
29#include "llvm/IR/Type.h"
30
31#include "slang_assert.h"
32#include "slang_rs_context.h"
33#include "slang_rs_export_element.h"
34#include "slang_version.h"
35
// Makes the enclosing function return false when ParentClass::equals(E)
// rejects E; otherwise execution continues with the next statement.
// NOTE(review): the expansion is a bare `if` statement (no do { } while (0)
// wrapper), so an `else` written directly after a use would bind to this
// `if`. Presumably used at the top of equals() overrides; confirm at callers.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  if (!ParentClass::equals(E))                \
    return false;
39
40namespace slang {
41
42namespace {
43
/* For the data types we support, their category, names, and size (in bits).
 *
 * IMPORTANT: The data types in this table should be at the same index
 * as specified by the corresponding DataType enum.
 *
 * Column order as read from the entries below (NOTE(review): the field
 * names are not visible here; confirm against the RSReflectionType
 * definition):
 *   1. category (PrimitiveDataType / MatrixDataType / ObjectDataType)
 *   2. element name, e.g. "FLOAT_32"
 *   3. short element name, e.g. "F32" (nullptr for the packed 16-bit and
 *      matrix types)
 *   4. size in bits
 *   5. C-side name, e.g. "float", "rsMatrix_4x4", "Allocation"
 *   6. Java-side name, e.g. "float", "Matrix4f", "Allocation"
 *   7./8. names such as "UByte"/"Short" for UNSIGNED_8 (presumably the
 *      vector-type prefixes used on the C and Java sides)
 *   9. a flag that is true only for UNSIGNED_8/16/32, whose Java-side name
 *      is a wider signed type (presumably marking Java type promotion)
 *
 * NOTE(review): the FLOAT_16 row uses "half" as its Java-side name, which is
 * not a Java primitive type; confirm this is intentional.
 */
static RSReflectionType gReflectionTypes[] = {
    {PrimitiveDataType, "FLOAT_16", "F16", 16, "half", "half", "Half", "Half", false},
    {PrimitiveDataType, "FLOAT_32", "F32", 32, "float", "float", "Float", "Float", false},
    {PrimitiveDataType, "FLOAT_64", "F64", 64, "double", "double", "Double", "Double",false},
    {PrimitiveDataType, "SIGNED_8", "I8", 8, "int8_t", "byte", "Byte", "Byte", false},
    {PrimitiveDataType, "SIGNED_16", "I16", 16, "int16_t", "short", "Short", "Short", false},
    {PrimitiveDataType, "SIGNED_32", "I32", 32, "int32_t", "int", "Int", "Int", false},
    {PrimitiveDataType, "SIGNED_64", "I64", 64, "int64_t", "long", "Long", "Long", false},
    {PrimitiveDataType, "UNSIGNED_8", "U8", 8, "uint8_t", "short", "UByte", "Short", true},
    {PrimitiveDataType, "UNSIGNED_16", "U16", 16, "uint16_t", "int", "UShort", "Int", true},
    {PrimitiveDataType, "UNSIGNED_32", "U32", 32, "uint32_t", "long", "UInt", "Long", true},
    {PrimitiveDataType, "UNSIGNED_64", "U64", 64, "uint64_t", "long", "ULong", "Long", false},

    {PrimitiveDataType, "BOOLEAN", "BOOLEAN", 8, "bool", "boolean", nullptr, nullptr, false},

    // Packed 16-bit pixel formats: no C or Java reflection names.
    {PrimitiveDataType, "UNSIGNED_5_6_5", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},
    {PrimitiveDataType, "UNSIGNED_5_5_5_1", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},
    {PrimitiveDataType, "UNSIGNED_4_4_4_4", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},

    // Matrix sizes are (element count) * 32 bits.
    {MatrixDataType, "MATRIX_2X2", nullptr, 4*32, "rsMatrix_2x2", "Matrix2f", nullptr, nullptr, false},
    {MatrixDataType, "MATRIX_3X3", nullptr, 9*32, "rsMatrix_3x3", "Matrix3f", nullptr, nullptr, false},
    {MatrixDataType, "MATRIX_4X4", nullptr, 16*32, "rsMatrix_4x4", "Matrix4f", nullptr, nullptr, false},

    // RS object types are 32 bits in 32-bit RS, but 256 bits in 64-bit RS.
    // This is handled specially by the GetSizeInBits() method.
    {ObjectDataType, "RS_ELEMENT", "ELEMENT", 32, "Element", "Element", nullptr, nullptr, false},
    {ObjectDataType, "RS_TYPE", "TYPE", 32, "Type", "Type", nullptr, nullptr, false},
    {ObjectDataType, "RS_ALLOCATION", "ALLOCATION", 32, "Allocation", "Allocation", nullptr, nullptr, false},
    {ObjectDataType, "RS_SAMPLER", "SAMPLER", 32, "Sampler", "Sampler", nullptr, nullptr, false},
    {ObjectDataType, "RS_SCRIPT", "SCRIPT", 32, "Script", "Script", nullptr, nullptr, false},
    {ObjectDataType, "RS_MESH", "MESH", 32, "Mesh", "Mesh", nullptr, nullptr, false},
    {ObjectDataType, "RS_PATH", "PATH", 32, "Path", "Path", nullptr, nullptr, false},

    {ObjectDataType, "RS_PROGRAM_FRAGMENT", "PROGRAM_FRAGMENT", 32, "ProgramFragment", "ProgramFragment", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_VERTEX", "PROGRAM_VERTEX", 32, "ProgramVertex", "ProgramVertex", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_RASTER", "PROGRAM_RASTER", 32, "ProgramRaster", "ProgramRaster", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_STORE", "PROGRAM_STORE", 32, "ProgramStore", "ProgramStore", nullptr, nullptr, false},
    {ObjectDataType, "RS_FONT", "FONT", 32, "Font", "Font", nullptr, nullptr, false}
};
87};
88
89const int kMaxVectorSize = 4;
90
91struct BuiltinInfo {
92  clang::BuiltinType::Kind builtinTypeKind;
93  DataType type;
94  /* TODO If we return std::string instead of llvm::StringRef, we could build
95   * the name instead of duplicating the entries.
96   */
97  const char *cname[kMaxVectorSize];
98};
99
100
101BuiltinInfo BuiltinInfoTable[] = {
102    {clang::BuiltinType::Bool, DataTypeBoolean,
103     {"bool", "bool2", "bool3", "bool4"}},
104    {clang::BuiltinType::Char_U, DataTypeUnsigned8,
105     {"uchar", "uchar2", "uchar3", "uchar4"}},
106    {clang::BuiltinType::UChar, DataTypeUnsigned8,
107     {"uchar", "uchar2", "uchar3", "uchar4"}},
108    {clang::BuiltinType::Char16, DataTypeSigned16,
109     {"short", "short2", "short3", "short4"}},
110    {clang::BuiltinType::Char32, DataTypeSigned32,
111     {"int", "int2", "int3", "int4"}},
112    {clang::BuiltinType::UShort, DataTypeUnsigned16,
113     {"ushort", "ushort2", "ushort3", "ushort4"}},
114    {clang::BuiltinType::UInt, DataTypeUnsigned32,
115     {"uint", "uint2", "uint3", "uint4"}},
116    {clang::BuiltinType::ULong, DataTypeUnsigned64,
117     {"ulong", "ulong2", "ulong3", "ulong4"}},
118    {clang::BuiltinType::ULongLong, DataTypeUnsigned64,
119     {"ulong", "ulong2", "ulong3", "ulong4"}},
120
121    {clang::BuiltinType::Char_S, DataTypeSigned8,
122     {"char", "char2", "char3", "char4"}},
123    {clang::BuiltinType::SChar, DataTypeSigned8,
124     {"char", "char2", "char3", "char4"}},
125    {clang::BuiltinType::Short, DataTypeSigned16,
126     {"short", "short2", "short3", "short4"}},
127    {clang::BuiltinType::Int, DataTypeSigned32,
128     {"int", "int2", "int3", "int4"}},
129    {clang::BuiltinType::Long, DataTypeSigned64,
130     {"long", "long2", "long3", "long4"}},
131    {clang::BuiltinType::LongLong, DataTypeSigned64,
132     {"long", "long2", "long3", "long4"}},
133    {clang::BuiltinType::Float, DataTypeFloat32,
134     {"float", "float2", "float3", "float4"}},
135    {clang::BuiltinType::Double, DataTypeFloat64,
136     {"double", "double2", "double3", "double4"}},
137};
138const int BuiltinInfoTableCount = sizeof(BuiltinInfoTable) / sizeof(BuiltinInfoTable[0]);
139
140struct NameAndPrimitiveType {
141  const char *name;
142  DataType dataType;
143};
144
145static NameAndPrimitiveType MatrixAndObjectDataTypes[] = {
146    {"rs_matrix2x2", DataTypeRSMatrix2x2},
147    {"rs_matrix3x3", DataTypeRSMatrix3x3},
148    {"rs_matrix4x4", DataTypeRSMatrix4x4},
149    {"rs_element", DataTypeRSElement},
150    {"rs_type", DataTypeRSType},
151    {"rs_allocation", DataTypeRSAllocation},
152    {"rs_sampler", DataTypeRSSampler},
153    {"rs_script", DataTypeRSScript},
154    {"rs_mesh", DataTypeRSMesh},
155    {"rs_path", DataTypeRSPath},
156    {"rs_program_fragment", DataTypeRSProgramFragment},
157    {"rs_program_vertex", DataTypeRSProgramVertex},
158    {"rs_program_raster", DataTypeRSProgramRaster},
159    {"rs_program_store", DataTypeRSProgramStore},
160    {"rs_font", DataTypeRSFont},
161};
162
163const int MatrixAndObjectDataTypesCount =
164    sizeof(MatrixAndObjectDataTypes) / sizeof(MatrixAndObjectDataTypes[0]);
165
166static const clang::Type *TypeExportableHelper(
167    const clang::Type *T,
168    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
169    slang::RSContext *Context,
170    const clang::VarDecl *VD,
171    const clang::RecordDecl *TopLevelRecord);
172
173template <unsigned N>
174static void ReportTypeError(slang::RSContext *Context,
175                            const clang::NamedDecl *ND,
176                            const clang::RecordDecl *TopLevelRecord,
177                            const char (&Message)[N],
178                            unsigned int TargetAPI = 0) {
179  // Attempt to use the type declaration first (if we have one).
180  // Fall back to the variable definition, if we are looking at something
181  // like an array declaration that can't be exported.
182  if (TopLevelRecord) {
183    Context->ReportError(TopLevelRecord->getLocation(), Message)
184        << TopLevelRecord->getName() << TargetAPI;
185  } else if (ND) {
186    Context->ReportError(ND->getLocation(), Message) << ND->getName()
187                                                     << TargetAPI;
188  } else {
189    slangAssert(false && "Variables should be validated before exporting");
190  }
191}
192
193static const clang::Type *ConstantArrayTypeExportableHelper(
194    const clang::ConstantArrayType *CAT,
195    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
196    slang::RSContext *Context,
197    const clang::VarDecl *VD,
198    const clang::RecordDecl *TopLevelRecord) {
199  // Check element type
200  const clang::Type *ElementType = GetConstantArrayElementType(CAT);
201  if (ElementType->isArrayType()) {
202    ReportTypeError(Context, VD, TopLevelRecord,
203                    "multidimensional arrays cannot be exported: '%0'");
204    return nullptr;
205  } else if (ElementType->isExtVectorType()) {
206    const clang::ExtVectorType *EVT =
207        static_cast<const clang::ExtVectorType*>(ElementType);
208    unsigned numElements = EVT->getNumElements();
209
210    const clang::Type *BaseElementType = GetExtVectorElementType(EVT);
211    if (!RSExportPrimitiveType::IsPrimitiveType(BaseElementType)) {
212      ReportTypeError(Context, VD, TopLevelRecord,
213        "vectors of non-primitive types cannot be exported: '%0'");
214      return nullptr;
215    }
216
217    if (numElements == 3 && CAT->getSize() != 1) {
218      ReportTypeError(Context, VD, TopLevelRecord,
219        "arrays of width 3 vector types cannot be exported: '%0'");
220      return nullptr;
221    }
222  }
223
224  if (TypeExportableHelper(ElementType, SPS, Context, VD,
225                           TopLevelRecord) == nullptr) {
226    return nullptr;
227  } else {
228    return CAT;
229  }
230}
231
232BuiltinInfo *FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind) {
233  for (int i = 0; i < BuiltinInfoTableCount; i++) {
234    if (builtinTypeKind == BuiltinInfoTable[i].builtinTypeKind) {
235      return &BuiltinInfoTable[i];
236    }
237  }
238  return nullptr;
239}
240
241static const clang::Type *TypeExportableHelper(
242    clang::Type const *T,
243    llvm::SmallPtrSet<clang::Type const *, 8> &SPS,
244    slang::RSContext *Context,
245    clang::VarDecl const *VD,
246    clang::RecordDecl const *TopLevelRecord) {
247  // Normalize first
248  if ((T = GetCanonicalType(T)) == nullptr)
249    return nullptr;
250
251  if (SPS.count(T))
252    return T;
253
254  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
255
256  switch (T->getTypeClass()) {
257    case clang::Type::Builtin: {
258      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
259      return FindBuiltinType(BT->getKind()) == nullptr ? nullptr : T;
260    }
261    case clang::Type::Record: {
262      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
263        return T;  // RS object type, no further checks are needed
264      }
265
266      // Check internal struct
267      if (T->isUnionType()) {
268        ReportTypeError(Context, VD, T->getAsUnionType()->getDecl(),
269                        "unions cannot be exported: '%0'");
270        return nullptr;
271      } else if (!T->isStructureType()) {
272        slangAssert(false && "Unknown type cannot be exported");
273        return nullptr;
274      }
275
276      clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
277      if (RD != nullptr) {
278        RD = RD->getDefinition();
279        if (RD == nullptr) {
280          ReportTypeError(Context, nullptr, T->getAsStructureType()->getDecl(),
281                          "struct is not defined in this module");
282          return nullptr;
283        }
284      }
285
286      if (!TopLevelRecord) {
287        TopLevelRecord = RD;
288      }
289      if (RD->getName().empty()) {
290        ReportTypeError(Context, nullptr, RD,
291                        "anonymous structures cannot be exported");
292        return nullptr;
293      }
294
295      // Fast check
296      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
297        return nullptr;
298
299      // Insert myself into checking set
300      SPS.insert(T);
301
302      // Check all element
303      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
304               FE = RD->field_end();
305           FI != FE;
306           FI++) {
307        const clang::FieldDecl *FD = *FI;
308        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
309        FT = GetCanonicalType(FT);
310
311        if (!TypeExportableHelper(FT, SPS, Context, VD, TopLevelRecord)) {
312          return nullptr;
313        }
314
315        // We don't support bit fields yet
316        //
317        // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
318        if (FD->isBitField()) {
319          Context->ReportError(
320              FD->getLocation(),
321              "bit fields are not able to be exported: '%0.%1'")
322              << RD->getName() << FD->getName();
323          return nullptr;
324        }
325      }
326
327      return T;
328    }
329    case clang::Type::Pointer: {
330      if (TopLevelRecord) {
331        ReportTypeError(Context, VD, TopLevelRecord,
332            "structures containing pointers cannot be exported: '%0'");
333        return nullptr;
334      }
335
336      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
337      const clang::Type *PointeeType = GetPointeeType(PT);
338
339      if (PointeeType->getTypeClass() == clang::Type::Pointer) {
340        ReportTypeError(Context, VD, TopLevelRecord,
341            "multiple levels of pointers cannot be exported: '%0'");
342        return nullptr;
343      }
344      // We don't support pointer with array-type pointee or unsupported pointee
345      // type
346      if (PointeeType->isArrayType() ||
347          (TypeExportableHelper(PointeeType, SPS, Context, VD,
348                                TopLevelRecord) == nullptr))
349        return nullptr;
350      else
351        return T;
352    }
353    case clang::Type::ExtVector: {
354      const clang::ExtVectorType *EVT =
355              static_cast<const clang::ExtVectorType*>(CTI);
356      // Only vector with size 2, 3 and 4 are supported.
357      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
358        return nullptr;
359
360      // Check base element type
361      const clang::Type *ElementType = GetExtVectorElementType(EVT);
362
363      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
364          (TypeExportableHelper(ElementType, SPS, Context, VD,
365                                TopLevelRecord) == nullptr))
366        return nullptr;
367      else
368        return T;
369    }
370    case clang::Type::ConstantArray: {
371      const clang::ConstantArrayType *CAT =
372              static_cast<const clang::ConstantArrayType*>(CTI);
373
374      return ConstantArrayTypeExportableHelper(CAT, SPS, Context, VD,
375                                               TopLevelRecord);
376    }
377    case clang::Type::Enum: {
378      // FIXME: We currently convert enums to integers, rather than reflecting
379      // a more complete (and nicer type-safe Java version).
380      return Context->getASTContext().IntTy.getTypePtr();
381    }
382    default: {
383      slangAssert(false && "Unknown type cannot be validated");
384      return nullptr;
385    }
386  }
387}
388
389// Return the type that can be used to create RSExportType, will always return
390// the canonical type.
391//
392// If the Type T is not exportable, this function returns nullptr. DiagEngine is
393// used to generate proper Clang diagnostic messages when a non-exportable type
394// is detected. TopLevelRecord is used to capture the highest struct (in the
395// case of a nested hierarchy) for detecting other types that cannot be exported
396// (mostly pointers within a struct).
397static const clang::Type *TypeExportable(const clang::Type *T,
398                                         slang::RSContext *Context,
399                                         const clang::VarDecl *VD) {
400  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
401      llvm::SmallPtrSet<const clang::Type*, 8>();
402
403  return TypeExportableHelper(T, SPS, Context, VD, nullptr);
404}
405
406static bool ValidateRSObjectInVarDecl(slang::RSContext *Context,
407                                      clang::VarDecl *VD, bool InCompositeType,
408                                      unsigned int TargetAPI) {
409  if (TargetAPI < SLANG_JB_TARGET_API) {
410    // Only if we are already in a composite type (like an array or structure).
411    if (InCompositeType) {
412      // Only if we are actually exported (i.e. non-static).
413      if (VD->hasLinkage() &&
414          (VD->getFormalLinkage() == clang::ExternalLinkage)) {
415        // Only if we are not a pointer to an object.
416        const clang::Type *T = GetCanonicalType(VD->getType().getTypePtr());
417        if (T->getTypeClass() != clang::Type::Pointer) {
418          ReportTypeError(Context, VD, nullptr,
419                          "arrays/structures containing RS object types "
420                          "cannot be exported in target API < %1: '%0'",
421                          SLANG_JB_TARGET_API);
422          return false;
423        }
424      }
425    }
426  }
427
428  return true;
429}
430
// Helper function for ValidateType(). We do a recursive descent on the
// type hierarchy to ensure that we can properly export/handle the
// declaration.
// \return true if the variable declaration is valid,
//         false if it is invalid (along with proper diagnostics).
//
// Context - RSContext through which diagnostics are reported.
// C - ASTContext (provides the builtin types compared against below).
// T - sub-type that we are validating; canonicalized in place.
// ND - (optional) top-level named declaration that we are validating.
// Loc - source location used for the Filterscript diagnostics.
// SPS - set of types we have already seen/validated.
// InCompositeType - true if we are within an outer composite type.
// UnionDecl - set if we are in a sub-type of a union.
// TargetAPI - target SDK API level.
// IsFilterscript - whether or not we are compiling for Filterscript
//
// NOTE(review): some false returns (the flexible-array/object-member "fast
// check" and the unknown-record assertion path) emit no diagnostic, despite
// the \return note above.
static bool ValidateTypeHelper(
    slang::RSContext *Context,
    clang::ASTContext &C,
    const clang::Type *&T,
    clang::NamedDecl *ND,
    clang::SourceLocation Loc,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    bool InCompositeType,
    clang::RecordDecl *UnionDecl,
    unsigned int TargetAPI,
    bool IsFilterscript) {
  if ((T = GetCanonicalType(T)) == nullptr)
    return true;

  // Records already entered in this traversal are not re-checked.
  if (SPS.count(T))
    return true;

  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();

  switch (T->getTypeClass()) {
    case clang::Type::Record: {
      // RS objects have API-level restrictions when the top-level decl is a
      // variable (see ValidateRSObjectInVarDecl).
      if (RSExportPrimitiveType::IsRSObjectType(T)) {
        clang::VarDecl *VD = (ND ? llvm::dyn_cast<clang::VarDecl>(ND) : nullptr);
        if (VD && !ValidateRSObjectInVarDecl(Context, VD, InCompositeType,
                                             TargetAPI)) {
          return false;
        }
      }

      // RS-specific types are accepted outside unions; RS object types
      // inside a union are an error. Other RS-specific types inside a union
      // fall through and are checked field by field below.
      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
        if (!UnionDecl) {
          return true;
        } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
          ReportTypeError(Context, nullptr, UnionDecl,
              "unions containing RS object types are not allowed");
          return false;
        }
      }

      clang::RecordDecl *RD = nullptr;

      // Check internal struct
      if (T->isUnionType()) {
        RD = T->getAsUnionType()->getDecl();
        UnionDecl = RD;
      } else if (T->isStructureType()) {
        RD = T->getAsStructureType()->getDecl();
      } else {
        slangAssert(false && "Unknown type cannot be exported");
        return false;
      }

      if (RD != nullptr) {
        RD = RD->getDefinition();
        if (RD == nullptr) {
          // FIXME: records without a definition are accepted unchecked.
          return true;
        }
      }

      // Fast check
      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
        return false;

      // Insert myself into checking set
      SPS.insert(T);

      // Check all elements; every field is inside a composite type.
      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
               FE = RD->field_end();
           FI != FE;
           FI++) {
        const clang::FieldDecl *FD = *FI;
        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
        FT = GetCanonicalType(FT);

        if (!ValidateTypeHelper(Context, C, FT, ND, Loc, SPS, true, UnionDecl,
                                TargetAPI, IsFilterscript)) {
          return false;
        }
      }

      return true;
    }

    case clang::Type::Builtin: {
      // Filterscript forbids the builtins listed below.
      // NOTE(review): unsigned 64-bit kinds (ULong/ULongLong) are not in this
      // list even though the message says "> 32 bits"; confirm whether that
      // is intentional.
      if (IsFilterscript) {
        clang::QualType QT = T->getCanonicalTypeInternal();
        if (QT == C.DoubleTy ||
            QT == C.LongDoubleTy ||
            QT == C.LongTy ||
            QT == C.LongLongTy) {
          if (ND) {
            Context->ReportError(
                Loc,
                "Builtin types > 32 bits in size are forbidden in "
                "Filterscript: '%0'")
                << ND->getName();
          } else {
            Context->ReportError(
                Loc,
                "Builtin types > 32 bits in size are forbidden in "
                "Filterscript");
          }
          return false;
        }
      }
      break;
    }

    case clang::Type::Pointer: {
      if (IsFilterscript) {
        if (ND) {
          Context->ReportError(Loc,
                               "Pointers are forbidden in Filterscript: '%0'")
              << ND->getName();
          return false;
        } else {
          // TODO(srhines): Find a better way to handle expressions (i.e. no
          // NamedDecl) involving pointers in FS that should be allowed.
          // An example would be calls to library functions like
          // rsMatrixMultiply() that take rs_matrixNxN * types.
        }
      }

      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
      const clang::Type *PointeeType = GetPointeeType(PT);

      // A pointer does not itself make its pointee "composite":
      // InCompositeType is passed through unchanged.
      return ValidateTypeHelper(Context, C, PointeeType, ND, Loc, SPS,
                                InCompositeType, UnionDecl, TargetAPI,
                                IsFilterscript);
    }

    case clang::Type::ExtVector: {
      const clang::ExtVectorType *EVT =
              static_cast<const clang::ExtVectorType*>(CTI);
      const clang::Type *ElementType = GetExtVectorElementType(EVT);
      // Before ICS, exported composites may not contain 3-wide vectors.
      if (TargetAPI < SLANG_ICS_TARGET_API &&
          InCompositeType &&
          EVT->getNumElements() == 3 &&
          ND &&
          ND->getFormalLinkage() == clang::ExternalLinkage) {
        ReportTypeError(Context, ND, nullptr,
                        "structs containing vectors of dimension 3 cannot "
                        "be exported at this API level: '%0'");
        return false;
      }
      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
                                UnionDecl, TargetAPI, IsFilterscript);
    }

    case clang::Type::ConstantArray: {
      // Array elements are validated as members of a composite type.
      const clang::ConstantArrayType *CAT = static_cast<const clang::ConstantArrayType*>(CTI);
      const clang::Type *ElementType = GetConstantArrayElementType(CAT);
      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
                                UnionDecl, TargetAPI, IsFilterscript);
    }

    default: {
      // All other type classes are accepted here.
      break;
    }
  }

  return true;
}
609}
610
611}  // namespace
612
// Builds a placeholder name of the form "<type>" or, when name is non-empty,
// "<type:name>".
//
// type - placeholder kind; a null pointer is treated as an empty string.
// name - optional qualifier appended after ':'.
//
// Built with plain std::string appends. This avoids needing <sstream>, which
// this file never included (std::stringstream was only reachable through a
// transitive include).
std::string CreateDummyName(const char *type, const std::string &name) {
  std::string S(1, '<');
  if (type != nullptr)
    S += type;
  if (!name.empty()) {
    S += ':';
    S += name;
  }
  S += '>';
  return S;
}
621}
622
623/****************************** RSExportType ******************************/
624bool RSExportType::NormalizeType(const clang::Type *&T,
625                                 llvm::StringRef &TypeName,
626                                 RSContext *Context,
627                                 const clang::VarDecl *VD) {
628  if ((T = TypeExportable(T, Context, VD)) == nullptr) {
629    return false;
630  }
631  // Get type name
632  TypeName = RSExportType::GetTypeName(T);
633  if (Context && TypeName.empty()) {
634    if (VD) {
635      Context->ReportError(VD->getLocation(),
636                           "anonymous types cannot be exported");
637    } else {
638      Context->ReportError("anonymous types cannot be exported");
639    }
640    return false;
641  }
642
643  return true;
644}
645
646bool RSExportType::ValidateType(slang::RSContext *Context, clang::ASTContext &C,
647                                clang::QualType QT, clang::NamedDecl *ND,
648                                clang::SourceLocation Loc,
649                                unsigned int TargetAPI, bool IsFilterscript) {
650  const clang::Type *T = QT.getTypePtr();
651  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
652      llvm::SmallPtrSet<const clang::Type*, 8>();
653
654  return ValidateTypeHelper(Context, C, T, ND, Loc, SPS, false, nullptr, TargetAPI,
655                            IsFilterscript);
656  return true;
657}
658
659bool RSExportType::ValidateVarDecl(slang::RSContext *Context,
660                                   clang::VarDecl *VD, unsigned int TargetAPI,
661                                   bool IsFilterscript) {
662  return ValidateType(Context, VD->getASTContext(), VD->getType(), VD,
663                      VD->getLocation(), TargetAPI, IsFilterscript);
664}
665
666const clang::Type
667*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
668  if (DD) {
669    clang::QualType T = DD->getType();
670
671    if (T.isNull())
672      return nullptr;
673    else
674      return T.getTypePtr();
675  }
676  return nullptr;
677}
678
679llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
680  T = GetCanonicalType(T);
681  if (T == nullptr)
682    return llvm::StringRef();
683
684  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
685
686  switch (T->getTypeClass()) {
687    case clang::Type::Builtin: {
688      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
689      BuiltinInfo *info = FindBuiltinType(BT->getKind());
690      if (info != nullptr) {
691        return info->cname[0];
692      }
693      slangAssert(false && "Unknown data type of the builtin");
694      break;
695    }
696    case clang::Type::Record: {
697      clang::RecordDecl *RD;
698      if (T->isStructureType()) {
699        RD = T->getAsStructureType()->getDecl();
700      } else {
701        break;
702      }
703
704      llvm::StringRef Name = RD->getName();
705      if (Name.empty()) {
706        if (RD->getTypedefNameForAnonDecl() != nullptr) {
707          Name = RD->getTypedefNameForAnonDecl()->getName();
708        }
709
710        if (Name.empty()) {
711          // Try to find a name from redeclaration (i.e. typedef)
712          for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
713                   RE = RD->redecls_end();
714               RI != RE;
715               RI++) {
716            slangAssert(*RI != nullptr && "cannot be NULL object");
717
718            Name = (*RI)->getName();
719            if (!Name.empty())
720              break;
721          }
722        }
723      }
724      return Name;
725    }
726    case clang::Type::Pointer: {
727      // "*" plus pointee name
728      const clang::PointerType *P = static_cast<const clang::PointerType*>(CTI);
729      const clang::Type *PT = GetPointeeType(P);
730      llvm::StringRef PointeeName;
731      if (NormalizeType(PT, PointeeName, nullptr, nullptr)) {
732        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
733        Name[0] = '*';
734        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
735        Name[PointeeName.size() + 1] = '\0';
736        return Name;
737      }
738      break;
739    }
740    case clang::Type::ExtVector: {
741      const clang::ExtVectorType *EVT =
742              static_cast<const clang::ExtVectorType*>(CTI);
743      return RSExportVectorType::GetTypeName(EVT);
744      break;
745    }
746    case clang::Type::ConstantArray : {
747      // Construct name for a constant array is too complicated.
748      return "<ConstantArray>";
749    }
750    default: {
751      break;
752    }
753  }
754
755  return llvm::StringRef();
756}
757
758
759RSExportType *RSExportType::Create(RSContext *Context,
760                                   const clang::Type *T,
761                                   const llvm::StringRef &TypeName) {
762  // Lookup the context to see whether the type was processed before.
763  // Newly created RSExportType will insert into context
764  // in RSExportType::RSExportType()
765  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
766
767  if (ETI != Context->export_types_end())
768    return ETI->second;
769
770  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
771
772  RSExportType *ET = nullptr;
773  switch (T->getTypeClass()) {
774    case clang::Type::Record: {
775      DataType dt = RSExportPrimitiveType::GetRSSpecificType(TypeName);
776      switch (dt) {
777        case DataTypeUnknown: {
778          // User-defined types
779          ET = RSExportRecordType::Create(Context,
780                                          T->getAsStructureType(),
781                                          TypeName);
782          break;
783        }
784        case DataTypeRSMatrix2x2: {
785          // 2 x 2 Matrix type
786          ET = RSExportMatrixType::Create(Context,
787                                          T->getAsStructureType(),
788                                          TypeName,
789                                          2);
790          break;
791        }
792        case DataTypeRSMatrix3x3: {
793          // 3 x 3 Matrix type
794          ET = RSExportMatrixType::Create(Context,
795                                          T->getAsStructureType(),
796                                          TypeName,
797                                          3);
798          break;
799        }
800        case DataTypeRSMatrix4x4: {
801          // 4 x 4 Matrix type
802          ET = RSExportMatrixType::Create(Context,
803                                          T->getAsStructureType(),
804                                          TypeName,
805                                          4);
806          break;
807        }
808        default: {
809          // Others are primitive types
810          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
811          break;
812        }
813      }
814      break;
815    }
816    case clang::Type::Builtin: {
817      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
818      break;
819    }
820    case clang::Type::Pointer: {
821      ET = RSExportPointerType::Create(Context,
822                                       static_cast<const clang::PointerType*>(CTI),
823                                       TypeName);
824      // FIXME: free the name (allocated in RSExportType::GetTypeName)
825      delete [] TypeName.data();
826      break;
827    }
828    case clang::Type::ExtVector: {
829      ET = RSExportVectorType::Create(Context,
830                                      static_cast<const clang::ExtVectorType*>(CTI),
831                                      TypeName);
832      break;
833    }
834    case clang::Type::ConstantArray: {
835      ET = RSExportConstantArrayType::Create(
836              Context,
837              static_cast<const clang::ConstantArrayType*>(CTI));
838      break;
839    }
840    default: {
841      Context->ReportError("unknown type cannot be exported: '%0'")
842          << T->getTypeClassName();
843      break;
844    }
845  }
846
847  return ET;
848}
849
850RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
851  llvm::StringRef TypeName;
852  if (NormalizeType(T, TypeName, Context, nullptr)) {
853    return Create(Context, T, TypeName);
854  } else {
855    return nullptr;
856  }
857}
858
859RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
860                                           const clang::VarDecl *VD) {
861  return RSExportType::Create(Context, GetTypeOfDecl(VD));
862}
863
864size_t RSExportType::getStoreSize() const {
865  return getRSContext()->getDataLayout()->getTypeStoreSize(getLLVMType());
866}
867
868size_t RSExportType::getAllocSize() const {
869    return getRSContext()->getDataLayout()->getTypeAllocSize(getLLVMType());
870}
871
872RSExportType::RSExportType(RSContext *Context,
873                           ExportClass Class,
874                           const llvm::StringRef &Name)
875    : RSExportable(Context, RSExportable::EX_TYPE),
876      mClass(Class),
877      // Make a copy on Name since memory stored @Name is either allocated in
878      // ASTContext or allocated in GetTypeName which will be destroyed later.
879      mName(Name.data(), Name.size()),
880      mLLVMType(nullptr) {
881  // Don't cache the type whose name start with '<'. Those type failed to
882  // get their name since constructing their name in GetTypeName() requiring
883  // complicated work.
884  if (!IsDummyName(Name)) {
885    // TODO(zonr): Need to check whether the insertion is successful or not.
886    Context->insertExportType(llvm::StringRef(Name), this);
887  }
888
889}
890
891bool RSExportType::keep() {
892  if (!RSExportable::keep())
893    return false;
894  // Invalidate converted LLVM type.
895  mLLVMType = nullptr;
896  return true;
897}
898
899bool RSExportType::equals(const RSExportable *E) const {
900  CHECK_PARENT_EQUALITY(RSExportable, E);
901  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
902}
903
904RSExportType::~RSExportType() {
905}
906
907/************************** RSExportPrimitiveType **************************/
// Name -> DataType table for RS-specific record types (the entries of
// MatrixAndObjectDataTypes). It starts empty and is populated lazily on the
// first call to GetRSSpecificType(const llvm::StringRef &).
llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
RSExportPrimitiveType::RSSpecificTypeMap;
910
911bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
912  if ((T != nullptr) && (T->getTypeClass() == clang::Type::Builtin))
913    return true;
914  else
915    return false;
916}
917
918DataType
919RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
920  if (TypeName.empty())
921    return DataTypeUnknown;
922
923  if (RSSpecificTypeMap->empty()) {
924    for (int i = 0; i < MatrixAndObjectDataTypesCount; i++) {
925      RSSpecificTypeMap->GetOrCreateValue(MatrixAndObjectDataTypes[i].name,
926                                          MatrixAndObjectDataTypes[i].dataType);
927    }
928  }
929
930  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
931  if (I == RSSpecificTypeMap->end())
932    return DataTypeUnknown;
933  else
934    return I->getValue();
935}
936
937DataType RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
938  T = GetCanonicalType(T);
939  if ((T == nullptr) || (T->getTypeClass() != clang::Type::Record))
940    return DataTypeUnknown;
941
942  return GetRSSpecificType( RSExportType::GetTypeName(T) );
943}
944
945bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
946    if (DT < 0 || DT >= DataTypeMax) {
947        return false;
948    }
949    return gReflectionTypes[DT].category == MatrixDataType;
950}
951
952bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
953    if (DT < 0 || DT >= DataTypeMax) {
954        return false;
955    }
956    return gReflectionTypes[DT].category == ObjectDataType;
957}
958
959bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
960  bool RSObjectTypeSeen = false;
961  while (T && T->isArrayType()) {
962    T = T->getArrayElementTypeNoTypeQual();
963  }
964
965  const clang::RecordType *RT = T->getAsStructureType();
966  if (!RT) {
967    return false;
968  }
969
970  const clang::RecordDecl *RD = RT->getDecl();
971  if (RD) {
972    RD = RD->getDefinition();
973  }
974  if (!RD) {
975    return false;
976  }
977
978  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
979         FE = RD->field_end();
980       FI != FE;
981       FI++) {
982    // We just look through all field declarations to see if we find a
983    // declaration for an RS object type (or an array of one).
984    const clang::FieldDecl *FD = *FI;
985    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
986    while (FT && FT->isArrayType()) {
987      FT = FT->getArrayElementTypeNoTypeQual();
988    }
989
990    DataType DT = GetRSSpecificType(FT);
991    if (IsRSObjectType(DT)) {
992      // RS object types definitely need to be zero-initialized
993      RSObjectTypeSeen = true;
994    } else {
995      switch (DT) {
996        case DataTypeRSMatrix2x2:
997        case DataTypeRSMatrix3x3:
998        case DataTypeRSMatrix4x4:
999          // Matrix types should get zero-initialized as well
1000          RSObjectTypeSeen = true;
1001          break;
1002        default:
1003          // Ignore all other primitive types
1004          break;
1005      }
1006      while (FT && FT->isArrayType()) {
1007        FT = FT->getArrayElementTypeNoTypeQual();
1008      }
1009      if (FT->isStructureType()) {
1010        // Recursively handle structs of structs (even though these can't
1011        // be exported, it is possible for a user to have them internally).
1012        RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
1013      }
1014    }
1015  }
1016
1017  return RSObjectTypeSeen;
1018}
1019
1020size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
1021  int type = EPT->getType();
1022  slangAssert((type > DataTypeUnknown && type < DataTypeMax) &&
1023              "RSExportPrimitiveType::GetSizeInBits : unknown data type");
1024  // All RS object types are 256 bits in 64-bit RS.
1025  if (EPT->isRSObjectType() && EPT->getRSContext()->is64Bit()) {
1026    return 256;
1027  }
1028  return gReflectionTypes[type].size_in_bits;
1029}
1030
1031DataType
1032RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
1033  if (T == nullptr)
1034    return DataTypeUnknown;
1035
1036  switch (T->getTypeClass()) {
1037    case clang::Type::Builtin: {
1038      const clang::BuiltinType *BT =
1039              static_cast<const clang::BuiltinType*>(T->getCanonicalTypeInternal().getTypePtr());
1040      BuiltinInfo *info = FindBuiltinType(BT->getKind());
1041      if (info != nullptr) {
1042        return info->type;
1043      }
1044      // The size of type WChar depend on platform so we abandon the support
1045      // to them.
1046      Context->ReportError("built-in type cannot be exported: '%0'")
1047          << T->getTypeClassName();
1048      break;
1049    }
1050    case clang::Type::Record: {
1051      // must be RS object type
1052      return RSExportPrimitiveType::GetRSSpecificType(T);
1053    }
1054    default: {
1055      Context->ReportError("primitive type cannot be exported: '%0'")
1056          << T->getTypeClassName();
1057      break;
1058    }
1059  }
1060
1061  return DataTypeUnknown;
1062}
1063
1064RSExportPrimitiveType
1065*RSExportPrimitiveType::Create(RSContext *Context,
1066                               const clang::Type *T,
1067                               const llvm::StringRef &TypeName,
1068                               bool Normalized) {
1069  DataType DT = GetDataType(Context, T);
1070
1071  if ((DT == DataTypeUnknown) || TypeName.empty())
1072    return nullptr;
1073  else
1074    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
1075                                     DT, Normalized);
1076}
1077
1078RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
1079                                                     const clang::Type *T) {
1080  llvm::StringRef TypeName;
1081  if (RSExportType::NormalizeType(T, TypeName, Context, nullptr)
1082      && IsPrimitiveType(T)) {
1083    return Create(Context, T, TypeName);
1084  } else {
1085    return nullptr;
1086  }
1087}
1088
1089llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
1090  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1091
1092  if (isRSObjectType()) {
1093    // struct {
1094    //   int *p;
1095    // } __attribute__((packed, aligned(pointer_size)))
1096    //
1097    // which is
1098    //
1099    // <{ [1 x i32] }> in LLVM
1100    //
1101    std::vector<llvm::Type *> Elements;
1102    if (getRSContext()->is64Bit()) {
1103      // 64-bit path
1104      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt64Ty(C), 4));
1105      return llvm::StructType::get(C, Elements, true);
1106    } else {
1107      // 32-bit legacy path
1108      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
1109      return llvm::StructType::get(C, Elements, true);
1110    }
1111  }
1112
1113  switch (mType) {
1114    case DataTypeFloat32: {
1115      return llvm::Type::getFloatTy(C);
1116      break;
1117    }
1118    case DataTypeFloat64: {
1119      return llvm::Type::getDoubleTy(C);
1120      break;
1121    }
1122    case DataTypeBoolean: {
1123      return llvm::Type::getInt1Ty(C);
1124      break;
1125    }
1126    case DataTypeSigned8:
1127    case DataTypeUnsigned8: {
1128      return llvm::Type::getInt8Ty(C);
1129      break;
1130    }
1131    case DataTypeSigned16:
1132    case DataTypeUnsigned16:
1133    case DataTypeUnsigned565:
1134    case DataTypeUnsigned5551:
1135    case DataTypeUnsigned4444: {
1136      return llvm::Type::getInt16Ty(C);
1137      break;
1138    }
1139    case DataTypeSigned32:
1140    case DataTypeUnsigned32: {
1141      return llvm::Type::getInt32Ty(C);
1142      break;
1143    }
1144    case DataTypeSigned64:
1145    case DataTypeUnsigned64: {
1146      return llvm::Type::getInt64Ty(C);
1147      break;
1148    }
1149    default: {
1150      slangAssert(false && "Unknown data type");
1151    }
1152  }
1153
1154  return nullptr;
1155}
1156
1157bool RSExportPrimitiveType::equals(const RSExportable *E) const {
1158  CHECK_PARENT_EQUALITY(RSExportType, E);
1159  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
1160}
1161
1162RSReflectionType *RSExportPrimitiveType::getRSReflectionType(DataType DT) {
1163  if (DT > DataTypeUnknown && DT < DataTypeMax) {
1164    return &gReflectionTypes[DT];
1165  } else {
1166    return nullptr;
1167  }
1168}
1169
1170/**************************** RSExportPointerType ****************************/
1171
1172RSExportPointerType
1173*RSExportPointerType::Create(RSContext *Context,
1174                             const clang::PointerType *PT,
1175                             const llvm::StringRef &TypeName) {
1176  const clang::Type *PointeeType = GetPointeeType(PT);
1177  const RSExportType *PointeeET;
1178
1179  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
1180    PointeeET = RSExportType::Create(Context, PointeeType);
1181  } else {
1182    // Double or higher dimension of pointer, export as int*
1183    PointeeET = RSExportPrimitiveType::Create(Context,
1184                    Context->getASTContext().IntTy.getTypePtr());
1185  }
1186
1187  if (PointeeET == nullptr) {
1188    // Error diagnostic is emitted for corresponding pointee type
1189    return nullptr;
1190  }
1191
1192  return new RSExportPointerType(Context, TypeName, PointeeET);
1193}
1194
1195llvm::Type *RSExportPointerType::convertToLLVMType() const {
1196  llvm::Type *PointeeType = mPointeeType->getLLVMType();
1197  return llvm::PointerType::getUnqual(PointeeType);
1198}
1199
1200bool RSExportPointerType::keep() {
1201  if (!RSExportType::keep())
1202    return false;
1203  const_cast<RSExportType*>(mPointeeType)->keep();
1204  return true;
1205}
1206
1207bool RSExportPointerType::equals(const RSExportable *E) const {
1208  CHECK_PARENT_EQUALITY(RSExportType, E);
1209  return (static_cast<const RSExportPointerType*>(E)
1210              ->getPointeeType()->equals(getPointeeType()));
1211}
1212
1213/***************************** RSExportVectorType *****************************/
1214llvm::StringRef
1215RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
1216  const clang::Type *ElementType = GetExtVectorElementType(EVT);
1217  llvm::StringRef name;
1218
1219  if ((ElementType->getTypeClass() != clang::Type::Builtin))
1220    return name;
1221
1222  const clang::BuiltinType *BT =
1223          static_cast<const clang::BuiltinType*>(
1224              ElementType->getCanonicalTypeInternal().getTypePtr());
1225
1226  if ((EVT->getNumElements() < 1) ||
1227      (EVT->getNumElements() > 4))
1228    return name;
1229
1230  BuiltinInfo *info = FindBuiltinType(BT->getKind());
1231  if (info != nullptr) {
1232    int I = EVT->getNumElements() - 1;
1233    if (I < kMaxVectorSize) {
1234      name = info->cname[I];
1235    } else {
1236      slangAssert(false && "Max vector is 4");
1237    }
1238  }
1239  return name;
1240}
1241
1242RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
1243                                               const clang::ExtVectorType *EVT,
1244                                               const llvm::StringRef &TypeName,
1245                                               bool Normalized) {
1246  slangAssert(EVT != nullptr && EVT->getTypeClass() == clang::Type::ExtVector);
1247
1248  const clang::Type *ElementType = GetExtVectorElementType(EVT);
1249  DataType DT = RSExportPrimitiveType::GetDataType(Context, ElementType);
1250
1251  if (DT != DataTypeUnknown)
1252    return new RSExportVectorType(Context,
1253                                  TypeName,
1254                                  DT,
1255                                  Normalized,
1256                                  EVT->getNumElements());
1257  else
1258    return nullptr;
1259}
1260
1261llvm::Type *RSExportVectorType::convertToLLVMType() const {
1262  llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
1263  return llvm::VectorType::get(ElementType, getNumElement());
1264}
1265
1266bool RSExportVectorType::equals(const RSExportable *E) const {
1267  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
1268  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
1269              == getNumElement());
1270}
1271
1272/***************************** RSExportMatrixType *****************************/
// Creates an export type for an rs_matrixDxD struct.
//
// If the struct definition is visible, it is validated to consist of exactly
// one field, a constant-size float array of Dim * Dim elements. Each violation
// reports an error at the struct's location and returns nullptr. If no
// definition is visible, the struct is assumed to be well formed.
//
// @param RT       the record type of the matrix struct (must be non-null)
// @param TypeName name for the new export type
// @param Dim      matrix dimension (must be > 1)
RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
                                               const clang::RecordType *RT,
                                               const llvm::StringRef &TypeName,
                                               unsigned Dim) {
  slangAssert((RT != nullptr) && (RT->getTypeClass() == clang::Type::Record));
  slangAssert((Dim > 1) && "Invalid dimension of matrix");

  // Check whether the struct rs_matrix is in our expected form (but assume it's
  // correct if we're not sure whether it's correct or not)
  const clang::RecordDecl* RD = RT->getDecl();
  RD = RD->getDefinition();
  if (RD != nullptr) {
    // Find definition, perform further examination
    if (RD->field_empty()) {
      Context->ReportError(
          RD->getLocation(),
          "invalid matrix struct: must have 1 field for saving values: '%0'")
          << RD->getName();
      return nullptr;
    }

    // The first (and only allowed) field must be a constant-size array...
    clang::RecordDecl::field_iterator FIT = RD->field_begin();
    const clang::FieldDecl *FD = *FIT;
    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
    if ((FT == nullptr) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
      Context->ReportError(RD->getLocation(),
                           "invalid matrix struct: first field should"
                           " be an array with constant size: '%0'")
          << RD->getName();
      return nullptr;
    }
    // ...whose element type is exactly the builtin `float`...
    const clang::ConstantArrayType *CAT =
      static_cast<const clang::ConstantArrayType *>(FT);
    const clang::Type *ElementType = GetConstantArrayElementType(CAT);
    if ((ElementType == nullptr) ||
        (ElementType->getTypeClass() != clang::Type::Builtin) ||
        (static_cast<const clang::BuiltinType *>(ElementType)->getKind() !=
         clang::BuiltinType::Float)) {
      Context->ReportError(RD->getLocation(),
                           "invalid matrix struct: first field "
                           "should be a float array: '%0'")
          << RD->getName();
      return nullptr;
    }

    // ...holding Dim * Dim values.
    if (CAT->getSize() != Dim * Dim) {
      Context->ReportError(RD->getLocation(),
                           "invalid matrix struct: first field "
                           "should be an array with size %0: '%1'")
          << (Dim * Dim) << (RD->getName());
      return nullptr;
    }

    // Reject any additional fields after the value array.
    FIT++;
    if (FIT != RD->field_end()) {
      Context->ReportError(RD->getLocation(),
                           "invalid matrix struct: must have "
                           "exactly 1 field: '%0'")
          << RD->getName();
      return nullptr;
    }
  }

  return new RSExportMatrixType(Context, TypeName, Dim);
}
1338
1339llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1340  // Construct LLVM type:
1341  // struct {
1342  //  float X[mDim * mDim];
1343  // }
1344
1345  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1346  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1347                                            mDim * mDim);
1348  return llvm::StructType::get(C, X, false);
1349}
1350
1351bool RSExportMatrixType::equals(const RSExportable *E) const {
1352  CHECK_PARENT_EQUALITY(RSExportType, E);
1353  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1354}
1355
1356/************************* RSExportConstantArrayType *************************/
1357RSExportConstantArrayType
1358*RSExportConstantArrayType::Create(RSContext *Context,
1359                                   const clang::ConstantArrayType *CAT) {
1360  slangAssert(CAT != nullptr && CAT->getTypeClass() == clang::Type::ConstantArray);
1361
1362  slangAssert((CAT->getSize().getActiveBits() < 32) && "array too large");
1363
1364  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1365  slangAssert((Size > 0) && "Constant array should have size greater than 0");
1366
1367  const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1368  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
1369
1370  if (ElementET == nullptr) {
1371    return nullptr;
1372  }
1373
1374  return new RSExportConstantArrayType(Context,
1375                                       ElementET,
1376                                       Size);
1377}
1378
1379llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1380  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
1381}
1382
1383bool RSExportConstantArrayType::keep() {
1384  if (!RSExportType::keep())
1385    return false;
1386  const_cast<RSExportType*>(mElementType)->keep();
1387  return true;
1388}
1389
1390bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1391  CHECK_PARENT_EQUALITY(RSExportType, E);
1392  const RSExportConstantArrayType *RHS =
1393      static_cast<const RSExportConstantArrayType*>(E);
1394  return ((getSize() == RHS->getSize()) &&
1395          (getElementType()->equals(RHS->getElementType())));
1396}
1397
1398/**************************** RSExportRecordType ****************************/
// Creates an export type for a user-defined struct.
//
// Uses clang's ASTRecordLayout for the struct's data size, total size and
// per-field byte offsets. Each field is exported via
// RSExportElement::CreateFromDecl(). Returns nullptr (asserting) if the struct
// has no definition, if a field is a bit-field, or (with an error report) if a
// field's type cannot be exported.
//
// NOTE(review): on the early nullptr returns inside the loop, ERT is not
// deleted. The RSExportType constructor has already inserted it into Context,
// so Context presumably owns it. Confirm before adding a delete.
// NOTE(review): no diagnostic is emitted here for bit-fields; presumably they
// are reported during earlier validation. Confirm.
RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
                                               const clang::RecordType *RT,
                                               const llvm::StringRef &TypeName,
                                               bool mIsArtificial) {
  slangAssert(RT != nullptr && RT->getTypeClass() == clang::Type::Record);

  const clang::RecordDecl *RD = RT->getDecl();
  slangAssert(RD->isStruct());

  RD = RD->getDefinition();
  if (RD == nullptr) {
    slangAssert(false && "struct is not defined in this module");
    return nullptr;
  }

  // Struct layout constructed by clang. We rely on this for obtaining the
  // alloc size of a struct and offset of every field in that struct.
  const clang::ASTRecordLayout *RL =
      &Context->getASTContext().getASTRecordLayout(RD);
  slangAssert((RL != nullptr) &&
      "Failed to retrieve the struct layout from Clang.");

  RSExportRecordType *ERT =
      new RSExportRecordType(Context,
                             TypeName,
                             RD->hasAttr<clang::PackedAttr>(),
                             mIsArtificial,
                             RL->getDataSize().getQuantity(),
                             RL->getSize().getQuantity());
  // Field index used to query RL->getFieldOffset().
  unsigned int Index = 0;

  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
           FE = RD->field_end();
       FI != FE;
       FI++, Index++) {

    // FIXME: All fields should be primitive type
    slangAssert(FI->getKind() == clang::Decl::Field);
    clang::FieldDecl *FD = *FI;

    if (FD->isBitField()) {
      return nullptr;
    }

    // Type
    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);

    if (ET != nullptr) {
      // getFieldOffset() is in bits; >> 3 converts to a byte offset.
      ERT->mFields.push_back(
          new Field(ET, FD->getName(), ERT,
                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
    } else {
      Context->ReportError(RD->getLocation(),
                           "field type cannot be exported: '%0.%1'")
          << RD->getName() << FD->getName();
      return nullptr;
    }
  }

  return ERT;
}
1460
1461llvm::Type *RSExportRecordType::convertToLLVMType() const {
1462  // Create an opaque type since struct may reference itself recursively.
1463
1464  // TODO(sliao): LLVM took out the OpaqueType. Any other to migrate to?
1465  std::vector<llvm::Type*> FieldTypes;
1466
1467  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1468       FI != FE;
1469       FI++) {
1470    const Field *F = *FI;
1471    const RSExportType *FET = F->getType();
1472
1473    FieldTypes.push_back(FET->getLLVMType());
1474  }
1475
1476  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1477                                               FieldTypes,
1478                                               mIsPacked);
1479  if (ST != nullptr) {
1480    return ST;
1481  } else {
1482    return nullptr;
1483  }
1484}
1485
1486bool RSExportRecordType::keep() {
1487  if (!RSExportType::keep())
1488    return false;
1489  for (std::list<const Field*>::iterator I = mFields.begin(),
1490          E = mFields.end();
1491       I != E;
1492       I++) {
1493    const_cast<RSExportType*>((*I)->getType())->keep();
1494  }
1495  return true;
1496}
1497
1498bool RSExportRecordType::equals(const RSExportable *E) const {
1499  CHECK_PARENT_EQUALITY(RSExportType, E);
1500
1501  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1502
1503  if (ERT->getFields().size() != getFields().size())
1504    return false;
1505
1506  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1507
1508  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1509    if (!(*AI)->getType()->equals((*BI)->getType()))
1510      return false;
1511    AI++;
1512    BI++;
1513  }
1514
1515  return true;
1516}
1517
1518void RSExportType::convertToRTD(RSReflectionTypeData *rtd) const {
1519    memset(rtd, 0, sizeof(*rtd));
1520    rtd->vecSize = 1;
1521
1522    switch(getClass()) {
1523    case RSExportType::ExportClassPrimitive: {
1524            const RSExportPrimitiveType *EPT = static_cast<const RSExportPrimitiveType*>(this);
1525            rtd->type = RSExportPrimitiveType::getRSReflectionType(EPT);
1526            return;
1527        }
1528    case RSExportType::ExportClassPointer: {
1529            const RSExportPointerType *EPT = static_cast<const RSExportPointerType*>(this);
1530            const RSExportType *PointeeType = EPT->getPointeeType();
1531            PointeeType->convertToRTD(rtd);
1532            rtd->isPointer = true;
1533            return;
1534        }
1535    case RSExportType::ExportClassVector: {
1536            const RSExportVectorType *EVT = static_cast<const RSExportVectorType*>(this);
1537            rtd->type = EVT->getRSReflectionType(EVT);
1538            rtd->vecSize = EVT->getNumElement();
1539            return;
1540        }
1541    case RSExportType::ExportClassMatrix: {
1542            const RSExportMatrixType *EMT = static_cast<const RSExportMatrixType*>(this);
1543            unsigned Dim = EMT->getDim();
1544            slangAssert((Dim >= 2) && (Dim <= 4));
1545            rtd->type = &gReflectionTypes[15 + Dim-2];
1546            return;
1547        }
1548    case RSExportType::ExportClassConstantArray: {
1549            const RSExportConstantArrayType* CAT =
1550              static_cast<const RSExportConstantArrayType*>(this);
1551            CAT->getElementType()->convertToRTD(rtd);
1552            rtd->arraySize = CAT->getSize();
1553            return;
1554        }
1555    case RSExportType::ExportClassRecord: {
1556            slangAssert(!"RSExportType::ExportClassRecord not implemented");
1557            return;// RS_TYPE_CLASS_NAME_PREFIX + ET->getName() + ".Item";
1558        }
1559    default: {
1560            slangAssert(false && "Unknown class of type");
1561        }
1562    }
1563}
1564
1565
1566}  // namespace slang
1567