slang_rs_export_type.cpp revision e4dd17d7b2a292a600756da7680beecd78f74033
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <list>
20#include <vector>
21
22#include "clang/AST/ASTContext.h"
23#include "clang/AST/Attr.h"
24#include "clang/AST/RecordLayout.h"
25
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/DerivedTypes.h"
29#include "llvm/IR/Type.h"
30
31#include "slang_assert.h"
32#include "slang_rs_context.h"
33#include "slang_rs_export_element.h"
34#include "slang_version.h"
35
// Makes the enclosing bool-returning function return false as soon as the
// parent class' equals() rejects E. The expansion is a braced if-statement, so
// a trailing semicolon at the call site is simply an empty statement.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  if (!(ParentClass::equals(E))) {           \
    return false;                            \
  }
39
40namespace slang {
41
42namespace {
43
/* For the data types we support, their category, names, and size (in bits).
 *
 * IMPORTANT: The data types in this table should be at the same index
 * as specified by the corresponding DataType enum.
 *
 * Column order (descriptive only; NOTE(review): confirm the field names
 * against the RSReflectionType declaration, which is not visible here):
 *   1. category: PrimitiveDataType, MatrixDataType or ObjectDataType.
 *   2. upper-case DataType spelling, e.g. "FLOAT_32", "RS_ALLOCATION".
 *   3. short spelling, e.g. "F32", "U8" (nullptr where there is none).
 *   4. size in bits (RS object rows: see the note above those rows).
 *   5. C/C++-side type name, e.g. "float", "int8_t", "rsMatrix_2x2".
 *   6. Java type name, e.g. "float", "byte", "Matrix2f". For the unsigned
 *      rows this is a wider signed Java type (UNSIGNED_8 -> "short").
 *   7./8. two capitalized names that differ only for the UNSIGNED_* rows
 *      (e.g. "UByte" vs "Short"); presumably the RS-side name and the Java
 *      type it is carried in.
 *   9. flag that is true only for UNSIGNED_8/16/32 and false for
 *      UNSIGNED_64; presumably "value must be widened for Java", with 64-bit
 *      excluded because Java has no wider integer type — TODO confirm.
 */
static RSReflectionType gReflectionTypes[] = {
    {PrimitiveDataType, "FLOAT_16", "F16", 16, "half", "short", "Half", "Half", false},
    {PrimitiveDataType, "FLOAT_32", "F32", 32, "float", "float", "Float", "Float", false},
    {PrimitiveDataType, "FLOAT_64", "F64", 64, "double", "double", "Double", "Double",false},
    {PrimitiveDataType, "SIGNED_8", "I8", 8, "int8_t", "byte", "Byte", "Byte", false},
    {PrimitiveDataType, "SIGNED_16", "I16", 16, "int16_t", "short", "Short", "Short", false},
    {PrimitiveDataType, "SIGNED_32", "I32", 32, "int32_t", "int", "Int", "Int", false},
    {PrimitiveDataType, "SIGNED_64", "I64", 64, "int64_t", "long", "Long", "Long", false},
    {PrimitiveDataType, "UNSIGNED_8", "U8", 8, "uint8_t", "short", "UByte", "Short", true},
    {PrimitiveDataType, "UNSIGNED_16", "U16", 16, "uint16_t", "int", "UShort", "Int", true},
    {PrimitiveDataType, "UNSIGNED_32", "U32", 32, "uint32_t", "long", "UInt", "Long", true},
    {PrimitiveDataType, "UNSIGNED_64", "U64", 64, "uint64_t", "long", "ULong", "Long", false},

    {PrimitiveDataType, "BOOLEAN", "BOOLEAN", 8, "bool", "boolean", nullptr, nullptr, false},

    // Packed 16-bit pixel formats; they have no standalone C or Java type.
    {PrimitiveDataType, "UNSIGNED_5_6_5", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},
    {PrimitiveDataType, "UNSIGNED_5_5_5_1", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},
    {PrimitiveDataType, "UNSIGNED_4_4_4_4", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},

    // Square float matrices: N*N elements of 32 bits each.
    {MatrixDataType, "MATRIX_2X2", nullptr, 4*32, "rsMatrix_2x2", "Matrix2f", nullptr, nullptr, false},
    {MatrixDataType, "MATRIX_3X3", nullptr, 9*32, "rsMatrix_3x3", "Matrix3f", nullptr, nullptr, false},
    {MatrixDataType, "MATRIX_4X4", nullptr, 16*32, "rsMatrix_4x4", "Matrix4f", nullptr, nullptr, false},

    // RS object types are 32 bits in 32-bit RS, but 256 bits in 64-bit RS.
    // This is handled specially by the GetSizeInBits() method.
    {ObjectDataType, "RS_ELEMENT", "ELEMENT", 32, "Element", "Element", nullptr, nullptr, false},
    {ObjectDataType, "RS_TYPE", "TYPE", 32, "Type", "Type", nullptr, nullptr, false},
    {ObjectDataType, "RS_ALLOCATION", "ALLOCATION", 32, "Allocation", "Allocation", nullptr, nullptr, false},
    {ObjectDataType, "RS_SAMPLER", "SAMPLER", 32, "Sampler", "Sampler", nullptr, nullptr, false},
    {ObjectDataType, "RS_SCRIPT", "SCRIPT", 32, "Script", "Script", nullptr, nullptr, false},
    {ObjectDataType, "RS_MESH", "MESH", 32, "Mesh", "Mesh", nullptr, nullptr, false},
    {ObjectDataType, "RS_PATH", "PATH", 32, "Path", "Path", nullptr, nullptr, false},

    {ObjectDataType, "RS_PROGRAM_FRAGMENT", "PROGRAM_FRAGMENT", 32, "ProgramFragment", "ProgramFragment", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_VERTEX", "PROGRAM_VERTEX", 32, "ProgramVertex", "ProgramVertex", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_RASTER", "PROGRAM_RASTER", 32, "ProgramRaster", "ProgramRaster", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_STORE", "PROGRAM_STORE", 32, "ProgramStore", "ProgramStore", nullptr, nullptr, false},
    {ObjectDataType, "RS_FONT", "FONT", 32, "Font", "Font", nullptr, nullptr, false}
};
88
// Number of spellings stored per builtin in BuiltinInfo::cname: the scalar
// form plus the 2-, 3- and 4-wide vector forms.
const int kMaxVectorSize = 4;

// One row of BuiltinInfoTable: maps a clang builtin type kind to its
// RenderScript DataType and to its RenderScript spellings. cname[0] is the
// scalar name (what GetTypeName() returns) and cname[N-1] is the N-wide vector
// name, e.g. {"float", "float2", "float3", "float4"}.
struct BuiltinInfo {
  clang::BuiltinType::Kind builtinTypeKind;
  DataType type;
  /* TODO If we return std::string instead of llvm::StringRef, we could build
   * the name instead of duplicating the entries.
   */
  const char *cname[kMaxVectorSize];
};
99
100
// Every clang builtin kind that RenderScript can export. FindBuiltinType()
// searches this table; a builtin kind missing from it (for example
// LongDouble) is rejected by the exportability check. Several clang kinds
// deliberately share one RS type: plain and explicitly signed/unsigned char
// map to char/uchar, and both long and long long map to the 64-bit RS "long".
BuiltinInfo BuiltinInfoTable[] = {
    {clang::BuiltinType::Bool, DataTypeBoolean,
     {"bool", "bool2", "bool3", "bool4"}},
    {clang::BuiltinType::Char_U, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::UChar, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    // NOTE(review): Char16/Char32 are mapped to the signed RS types even
    // though char16_t/char32_t are unsigned in C++ — confirm this is intended.
    {clang::BuiltinType::Char16, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Char32, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::UShort, DataTypeUnsigned16,
     {"ushort", "ushort2", "ushort3", "ushort4"}},
    {clang::BuiltinType::UInt, DataTypeUnsigned32,
     {"uint", "uint2", "uint3", "uint4"}},
    {clang::BuiltinType::ULong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},
    {clang::BuiltinType::ULongLong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},

    {clang::BuiltinType::Char_S, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::SChar, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::Short, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Int, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::Long, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::LongLong, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::Half, DataTypeFloat16,
     {"half", "half2", "half3", "half4"}},
    {clang::BuiltinType::Float, DataTypeFloat32,
     {"float", "float2", "float3", "float4"}},
    {clang::BuiltinType::Double, DataTypeFloat64,
     {"double", "double2", "double3", "double4"}},
};
// Number of rows in BuiltinInfoTable.
const int BuiltinInfoTableCount = sizeof(BuiltinInfoTable) / sizeof(BuiltinInfoTable[0]);
141
// Pairs an RS-specific type name with the DataType it denotes.
struct NameAndPrimitiveType {
  const char *name;
  DataType dataType;
};

// RS matrix and RS object handle types, keyed by their RenderScript source
// spelling ("rs_matrix2x2", "rs_allocation", ...). NOTE(review): not
// referenced in this part of the file; presumably used by the name-based
// RS-specific type lookup elsewhere — confirm.
static NameAndPrimitiveType MatrixAndObjectDataTypes[] = {
    {"rs_matrix2x2", DataTypeRSMatrix2x2},
    {"rs_matrix3x3", DataTypeRSMatrix3x3},
    {"rs_matrix4x4", DataTypeRSMatrix4x4},
    {"rs_element", DataTypeRSElement},
    {"rs_type", DataTypeRSType},
    {"rs_allocation", DataTypeRSAllocation},
    {"rs_sampler", DataTypeRSSampler},
    {"rs_script", DataTypeRSScript},
    {"rs_mesh", DataTypeRSMesh},
    {"rs_path", DataTypeRSPath},
    {"rs_program_fragment", DataTypeRSProgramFragment},
    {"rs_program_vertex", DataTypeRSProgramVertex},
    {"rs_program_raster", DataTypeRSProgramRaster},
    {"rs_program_store", DataTypeRSProgramStore},
    {"rs_font", DataTypeRSFont},
};

// Number of rows in MatrixAndObjectDataTypes.
const int MatrixAndObjectDataTypesCount =
    sizeof(MatrixAndObjectDataTypes) / sizeof(MatrixAndObjectDataTypes[0]);
167
// Forward declaration: TypeExportableHelper() and
// ConstantArrayTypeExportableHelper() call each other recursively, so one of
// them must be declared before both are defined.
static const clang::Type *TypeExportableHelper(
    const clang::Type *T,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    slang::RSContext *Context,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord);
174
175template <unsigned N>
176static void ReportTypeError(slang::RSContext *Context,
177                            const clang::NamedDecl *ND,
178                            const clang::RecordDecl *TopLevelRecord,
179                            const char (&Message)[N],
180                            unsigned int TargetAPI = 0) {
181  // Attempt to use the type declaration first (if we have one).
182  // Fall back to the variable definition, if we are looking at something
183  // like an array declaration that can't be exported.
184  if (TopLevelRecord) {
185    Context->ReportError(TopLevelRecord->getLocation(), Message)
186        << TopLevelRecord->getName() << TargetAPI;
187  } else if (ND) {
188    Context->ReportError(ND->getLocation(), Message) << ND->getName()
189                                                     << TargetAPI;
190  } else {
191    slangAssert(false && "Variables should be validated before exporting");
192  }
193}
194
195static const clang::Type *ConstantArrayTypeExportableHelper(
196    const clang::ConstantArrayType *CAT,
197    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
198    slang::RSContext *Context,
199    const clang::VarDecl *VD,
200    const clang::RecordDecl *TopLevelRecord) {
201  // Check element type
202  const clang::Type *ElementType = GetConstantArrayElementType(CAT);
203  if (ElementType->isArrayType()) {
204    ReportTypeError(Context, VD, TopLevelRecord,
205                    "multidimensional arrays cannot be exported: '%0'");
206    return nullptr;
207  } else if (ElementType->isExtVectorType()) {
208    const clang::ExtVectorType *EVT =
209        static_cast<const clang::ExtVectorType*>(ElementType);
210    unsigned numElements = EVT->getNumElements();
211
212    const clang::Type *BaseElementType = GetExtVectorElementType(EVT);
213    if (!RSExportPrimitiveType::IsPrimitiveType(BaseElementType)) {
214      ReportTypeError(Context, VD, TopLevelRecord,
215        "vectors of non-primitive types cannot be exported: '%0'");
216      return nullptr;
217    }
218
219    if (numElements == 3 && CAT->getSize() != 1) {
220      ReportTypeError(Context, VD, TopLevelRecord,
221        "arrays of width 3 vector types cannot be exported: '%0'");
222      return nullptr;
223    }
224  }
225
226  if (TypeExportableHelper(ElementType, SPS, Context, VD,
227                           TopLevelRecord) == nullptr) {
228    return nullptr;
229  } else {
230    return CAT;
231  }
232}
233
234BuiltinInfo *FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind) {
235  for (int i = 0; i < BuiltinInfoTableCount; i++) {
236    if (builtinTypeKind == BuiltinInfoTable[i].builtinTypeKind) {
237      return &BuiltinInfoTable[i];
238    }
239  }
240  return nullptr;
241}
242
243static const clang::Type *TypeExportableHelper(
244    clang::Type const *T,
245    llvm::SmallPtrSet<clang::Type const *, 8> &SPS,
246    slang::RSContext *Context,
247    clang::VarDecl const *VD,
248    clang::RecordDecl const *TopLevelRecord) {
249  // Normalize first
250  if ((T = GetCanonicalType(T)) == nullptr)
251    return nullptr;
252
253  if (SPS.count(T))
254    return T;
255
256  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
257
258  switch (T->getTypeClass()) {
259    case clang::Type::Builtin: {
260      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
261      return FindBuiltinType(BT->getKind()) == nullptr ? nullptr : T;
262    }
263    case clang::Type::Record: {
264      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
265        return T;  // RS object type, no further checks are needed
266      }
267
268      // Check internal struct
269      if (T->isUnionType()) {
270        ReportTypeError(Context, VD, T->getAsUnionType()->getDecl(),
271                        "unions cannot be exported: '%0'");
272        return nullptr;
273      } else if (!T->isStructureType()) {
274        slangAssert(false && "Unknown type cannot be exported");
275        return nullptr;
276      }
277
278      clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
279      if (RD != nullptr) {
280        RD = RD->getDefinition();
281        if (RD == nullptr) {
282          ReportTypeError(Context, nullptr, T->getAsStructureType()->getDecl(),
283                          "struct is not defined in this module");
284          return nullptr;
285        }
286      }
287
288      if (!TopLevelRecord) {
289        TopLevelRecord = RD;
290      }
291      if (RD->getName().empty()) {
292        ReportTypeError(Context, nullptr, RD,
293                        "anonymous structures cannot be exported");
294        return nullptr;
295      }
296
297      // Fast check
298      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
299        return nullptr;
300
301      // Insert myself into checking set
302      SPS.insert(T);
303
304      // Check all element
305      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
306               FE = RD->field_end();
307           FI != FE;
308           FI++) {
309        const clang::FieldDecl *FD = *FI;
310        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
311        FT = GetCanonicalType(FT);
312
313        if (!TypeExportableHelper(FT, SPS, Context, VD, TopLevelRecord)) {
314          return nullptr;
315        }
316
317        // We don't support bit fields yet
318        //
319        // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
320        if (FD->isBitField()) {
321          Context->ReportError(
322              FD->getLocation(),
323              "bit fields are not able to be exported: '%0.%1'")
324              << RD->getName() << FD->getName();
325          return nullptr;
326        }
327      }
328
329      return T;
330    }
331    case clang::Type::Pointer: {
332      if (TopLevelRecord) {
333        ReportTypeError(Context, VD, TopLevelRecord,
334            "structures containing pointers cannot be exported: '%0'");
335        return nullptr;
336      }
337
338      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
339      const clang::Type *PointeeType = GetPointeeType(PT);
340
341      if (PointeeType->getTypeClass() == clang::Type::Pointer) {
342        ReportTypeError(Context, VD, TopLevelRecord,
343            "multiple levels of pointers cannot be exported: '%0'");
344        return nullptr;
345      }
346      // We don't support pointer with array-type pointee or unsupported pointee
347      // type
348      if (PointeeType->isArrayType() ||
349          (TypeExportableHelper(PointeeType, SPS, Context, VD,
350                                TopLevelRecord) == nullptr))
351        return nullptr;
352      else
353        return T;
354    }
355    case clang::Type::ExtVector: {
356      const clang::ExtVectorType *EVT =
357              static_cast<const clang::ExtVectorType*>(CTI);
358      // Only vector with size 2, 3 and 4 are supported.
359      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
360        return nullptr;
361
362      // Check base element type
363      const clang::Type *ElementType = GetExtVectorElementType(EVT);
364
365      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
366          (TypeExportableHelper(ElementType, SPS, Context, VD,
367                                TopLevelRecord) == nullptr))
368        return nullptr;
369      else
370        return T;
371    }
372    case clang::Type::ConstantArray: {
373      const clang::ConstantArrayType *CAT =
374              static_cast<const clang::ConstantArrayType*>(CTI);
375
376      return ConstantArrayTypeExportableHelper(CAT, SPS, Context, VD,
377                                               TopLevelRecord);
378    }
379    case clang::Type::Enum: {
380      // FIXME: We currently convert enums to integers, rather than reflecting
381      // a more complete (and nicer type-safe Java version).
382      return Context->getASTContext().IntTy.getTypePtr();
383    }
384    default: {
385      slangAssert(false && "Unknown type cannot be validated");
386      return nullptr;
387    }
388  }
389}
390
391// Return the type that can be used to create RSExportType, will always return
392// the canonical type.
393//
394// If the Type T is not exportable, this function returns nullptr. DiagEngine is
395// used to generate proper Clang diagnostic messages when a non-exportable type
396// is detected. TopLevelRecord is used to capture the highest struct (in the
397// case of a nested hierarchy) for detecting other types that cannot be exported
398// (mostly pointers within a struct).
399static const clang::Type *TypeExportable(const clang::Type *T,
400                                         slang::RSContext *Context,
401                                         const clang::VarDecl *VD) {
402  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
403      llvm::SmallPtrSet<const clang::Type*, 8>();
404
405  return TypeExportableHelper(T, SPS, Context, VD, nullptr);
406}
407
408static bool ValidateRSObjectInVarDecl(slang::RSContext *Context,
409                                      clang::VarDecl *VD, bool InCompositeType,
410                                      unsigned int TargetAPI) {
411  if (TargetAPI < SLANG_JB_TARGET_API) {
412    // Only if we are already in a composite type (like an array or structure).
413    if (InCompositeType) {
414      // Only if we are actually exported (i.e. non-static).
415      if (VD->hasLinkage() &&
416          (VD->getFormalLinkage() == clang::ExternalLinkage)) {
417        // Only if we are not a pointer to an object.
418        const clang::Type *T = GetCanonicalType(VD->getType().getTypePtr());
419        if (T->getTypeClass() != clang::Type::Pointer) {
420          ReportTypeError(Context, VD, nullptr,
421                          "arrays/structures containing RS object types "
422                          "cannot be exported in target API < %1: '%0'",
423                          SLANG_JB_TARGET_API);
424          return false;
425        }
426      }
427    }
428  }
429
430  return true;
431}
432
// Helper function for ValidateType(). We do a recursive descent on the
// type hierarchy to ensure that we can properly export/handle the
// declaration.
// \return true if the variable declaration is valid,
//         false if it is invalid (along with proper diagnostics).
//
// Context - RS context used to report diagnostics.
// C - ASTContext (for diagnostics + builtin types).
// T - sub-type that we are validating. Passed by reference: it is replaced
//     with its canonical form.
// ND - (optional) top-level named declaration that we are validating. It is
//      null for bare expressions.
// Loc - source location used for the Filterscript diagnostics.
// SPS - set of types we have already seen/validated.
// InCompositeType - true if we are within an outer composite type.
// UnionDecl - set if we are in a sub-type of a union.
// TargetAPI - target SDK API level.
// IsFilterscript - whether or not we are compiling for Filterscript
static bool ValidateTypeHelper(
    slang::RSContext *Context,
    clang::ASTContext &C,
    const clang::Type *&T,
    clang::NamedDecl *ND,
    clang::SourceLocation Loc,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    bool InCompositeType,
    clang::RecordDecl *UnionDecl,
    unsigned int TargetAPI,
    bool IsFilterscript) {
  if ((T = GetCanonicalType(T)) == nullptr)
    return true;

  // Record types already visited in this descent are not re-checked.
  if (SPS.count(T))
    return true;

  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();

  switch (T->getTypeClass()) {
    case clang::Type::Record: {
      // RS objects nested in composites are restricted on older targets.
      if (RSExportPrimitiveType::IsRSObjectType(T)) {
        clang::VarDecl *VD = (ND ? llvm::dyn_cast<clang::VarDecl>(ND) : nullptr);
        if (VD && !ValidateRSObjectInVarDecl(Context, VD, InCompositeType,
                                             TargetAPI)) {
          return false;
        }
      }

      // RS-specific types (matrices, RS objects) are accepted unless they sit
      // inside a union. RS objects inside a union are an error. Other
      // RS-specific types inside a union fall through to the field checks
      // below.
      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
        if (!UnionDecl) {
          return true;
        } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
          ReportTypeError(Context, nullptr, UnionDecl,
              "unions containing RS object types are not allowed");
          return false;
        }
      }

      clang::RecordDecl *RD = nullptr;

      // Check internal struct
      if (T->isUnionType()) {
        RD = T->getAsUnionType()->getDecl();
        // Everything below this union is checked as "inside a union".
        UnionDecl = RD;
      } else if (T->isStructureType()) {
        RD = T->getAsStructureType()->getDecl();
      } else {
        slangAssert(false && "Unknown type cannot be exported");
        return false;
      }

      if (RD != nullptr) {
        RD = RD->getDefinition();
        if (RD == nullptr) {
          // FIXME
          // Declared but not defined here: accepted without checking fields.
          return true;
        }
      }

      // Fast check
      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
        return false;

      // Insert myself into checking set
      SPS.insert(T);

      // Check all elements
      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
               FE = RD->field_end();
           FI != FE;
           FI++) {
        const clang::FieldDecl *FD = *FI;
        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
        FT = GetCanonicalType(FT);

        // Fields are always inside a composite type (InCompositeType = true).
        if (!ValidateTypeHelper(Context, C, FT, ND, Loc, SPS, true, UnionDecl,
                                TargetAPI, IsFilterscript)) {
          return false;
        }
      }

      return true;
    }

    case clang::Type::Builtin: {
      if (IsFilterscript) {
        clang::QualType QT = T->getCanonicalTypeInternal();
        // NOTE(review): the unsigned 64-bit builtins (unsigned long /
        // unsigned long long) are not in this list, so they pass this check —
        // confirm that is intended.
        if (QT == C.DoubleTy ||
            QT == C.LongDoubleTy ||
            QT == C.LongTy ||
            QT == C.LongLongTy) {
          if (ND) {
            Context->ReportError(
                Loc,
                "Builtin types > 32 bits in size are forbidden in "
                "Filterscript: '%0'")
                << ND->getName();
          } else {
            Context->ReportError(
                Loc,
                "Builtin types > 32 bits in size are forbidden in "
                "Filterscript");
          }
          return false;
        }
      }
      break;
    }

    case clang::Type::Pointer: {
      if (IsFilterscript) {
        if (ND) {
          Context->ReportError(Loc,
                               "Pointers are forbidden in Filterscript: '%0'")
              << ND->getName();
          return false;
        } else {
          // TODO(srhines): Find a better way to handle expressions (i.e. no
          // NamedDecl) involving pointers in FS that should be allowed.
          // An example would be calls to library functions like
          // rsMatrixMultiply() that take rs_matrixNxN * types.
        }
      }

      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
      const clang::Type *PointeeType = GetPointeeType(PT);

      // A pointer does not itself make its pointee "composite": the current
      // InCompositeType value is passed through unchanged.
      return ValidateTypeHelper(Context, C, PointeeType, ND, Loc, SPS,
                                InCompositeType, UnionDecl, TargetAPI,
                                IsFilterscript);
    }

    case clang::Type::ExtVector: {
      const clang::ExtVectorType *EVT =
              static_cast<const clang::ExtVectorType*>(CTI);
      const clang::Type *ElementType = GetExtVectorElementType(EVT);
      // Exported composites may not contain 3-wide vectors before ICS.
      if (TargetAPI < SLANG_ICS_TARGET_API &&
          InCompositeType &&
          EVT->getNumElements() == 3 &&
          ND &&
          ND->getFormalLinkage() == clang::ExternalLinkage) {
        ReportTypeError(Context, ND, nullptr,
                        "structs containing vectors of dimension 3 cannot "
                        "be exported at this API level: '%0'");
        return false;
      }
      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
                                UnionDecl, TargetAPI, IsFilterscript);
    }

    case clang::Type::ConstantArray: {
      // Array elements are validated as being inside a composite type.
      const clang::ConstantArrayType *CAT = static_cast<const clang::ConstantArrayType*>(CTI);
      const clang::Type *ElementType = GetConstantArrayElementType(CAT);
      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
                                UnionDecl, TargetAPI, IsFilterscript);
    }

    default: {
      break;
    }
  }

  return true;
}
612
613}  // namespace
614
// Builds a placeholder name of the form "<type>", or "<type:name>" when name
// is non-empty. type must be a non-null C string.
//
// Implemented with plain std::string appends. The previous std::stringstream
// version depended on <sstream>, which this file never includes (only
// <list> and <vector> are included directly). A stream is also unnecessary
// for four appends.
std::string CreateDummyName(const char *type, const std::string &name) {
  std::string Result("<");
  Result += type;
  if (!name.empty()) {
    Result += ':';
    Result += name;
  }
  Result += '>';
  return Result;
}
624
625/****************************** RSExportType ******************************/
626bool RSExportType::NormalizeType(const clang::Type *&T,
627                                 llvm::StringRef &TypeName,
628                                 RSContext *Context,
629                                 const clang::VarDecl *VD) {
630  if ((T = TypeExportable(T, Context, VD)) == nullptr) {
631    return false;
632  }
633  // Get type name
634  TypeName = RSExportType::GetTypeName(T);
635  if (Context && TypeName.empty()) {
636    if (VD) {
637      Context->ReportError(VD->getLocation(),
638                           "anonymous types cannot be exported");
639    } else {
640      Context->ReportError("anonymous types cannot be exported");
641    }
642    return false;
643  }
644
645  return true;
646}
647
648bool RSExportType::ValidateType(slang::RSContext *Context, clang::ASTContext &C,
649                                clang::QualType QT, clang::NamedDecl *ND,
650                                clang::SourceLocation Loc,
651                                unsigned int TargetAPI, bool IsFilterscript) {
652  const clang::Type *T = QT.getTypePtr();
653  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
654      llvm::SmallPtrSet<const clang::Type*, 8>();
655
656  return ValidateTypeHelper(Context, C, T, ND, Loc, SPS, false, nullptr, TargetAPI,
657                            IsFilterscript);
658  return true;
659}
660
661bool RSExportType::ValidateVarDecl(slang::RSContext *Context,
662                                   clang::VarDecl *VD, unsigned int TargetAPI,
663                                   bool IsFilterscript) {
664  return ValidateType(Context, VD->getASTContext(), VD->getType(), VD,
665                      VD->getLocation(), TargetAPI, IsFilterscript);
666}
667
668const clang::Type
669*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
670  if (DD) {
671    clang::QualType T = DD->getType();
672
673    if (T.isNull())
674      return nullptr;
675    else
676      return T.getTypePtr();
677  }
678  return nullptr;
679}
680
681llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
682  T = GetCanonicalType(T);
683  if (T == nullptr)
684    return llvm::StringRef();
685
686  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
687
688  switch (T->getTypeClass()) {
689    case clang::Type::Builtin: {
690      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
691      BuiltinInfo *info = FindBuiltinType(BT->getKind());
692      if (info != nullptr) {
693        return info->cname[0];
694      }
695      slangAssert(false && "Unknown data type of the builtin");
696      break;
697    }
698    case clang::Type::Record: {
699      clang::RecordDecl *RD;
700      if (T->isStructureType()) {
701        RD = T->getAsStructureType()->getDecl();
702      } else {
703        break;
704      }
705
706      llvm::StringRef Name = RD->getName();
707      if (Name.empty()) {
708        if (RD->getTypedefNameForAnonDecl() != nullptr) {
709          Name = RD->getTypedefNameForAnonDecl()->getName();
710        }
711
712        if (Name.empty()) {
713          // Try to find a name from redeclaration (i.e. typedef)
714          for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
715                   RE = RD->redecls_end();
716               RI != RE;
717               RI++) {
718            slangAssert(*RI != nullptr && "cannot be NULL object");
719
720            Name = (*RI)->getName();
721            if (!Name.empty())
722              break;
723          }
724        }
725      }
726      return Name;
727    }
728    case clang::Type::Pointer: {
729      // "*" plus pointee name
730      const clang::PointerType *P = static_cast<const clang::PointerType*>(CTI);
731      const clang::Type *PT = GetPointeeType(P);
732      llvm::StringRef PointeeName;
733      if (NormalizeType(PT, PointeeName, nullptr, nullptr)) {
734        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
735        Name[0] = '*';
736        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
737        Name[PointeeName.size() + 1] = '\0';
738        return Name;
739      }
740      break;
741    }
742    case clang::Type::ExtVector: {
743      const clang::ExtVectorType *EVT =
744              static_cast<const clang::ExtVectorType*>(CTI);
745      return RSExportVectorType::GetTypeName(EVT);
746      break;
747    }
748    case clang::Type::ConstantArray : {
749      // Construct name for a constant array is too complicated.
750      return "<ConstantArray>";
751    }
752    default: {
753      break;
754    }
755  }
756
757  return llvm::StringRef();
758}
759
760
761RSExportType *RSExportType::Create(RSContext *Context,
762                                   const clang::Type *T,
763                                   const llvm::StringRef &TypeName) {
764  // Lookup the context to see whether the type was processed before.
765  // Newly created RSExportType will insert into context
766  // in RSExportType::RSExportType()
767  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
768
769  if (ETI != Context->export_types_end())
770    return ETI->second;
771
772  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
773
774  RSExportType *ET = nullptr;
775  switch (T->getTypeClass()) {
776    case clang::Type::Record: {
777      DataType dt = RSExportPrimitiveType::GetRSSpecificType(TypeName);
778      switch (dt) {
779        case DataTypeUnknown: {
780          // User-defined types
781          ET = RSExportRecordType::Create(Context,
782                                          T->getAsStructureType(),
783                                          TypeName);
784          break;
785        }
786        case DataTypeRSMatrix2x2: {
787          // 2 x 2 Matrix type
788          ET = RSExportMatrixType::Create(Context,
789                                          T->getAsStructureType(),
790                                          TypeName,
791                                          2);
792          break;
793        }
794        case DataTypeRSMatrix3x3: {
795          // 3 x 3 Matrix type
796          ET = RSExportMatrixType::Create(Context,
797                                          T->getAsStructureType(),
798                                          TypeName,
799                                          3);
800          break;
801        }
802        case DataTypeRSMatrix4x4: {
803          // 4 x 4 Matrix type
804          ET = RSExportMatrixType::Create(Context,
805                                          T->getAsStructureType(),
806                                          TypeName,
807                                          4);
808          break;
809        }
810        default: {
811          // Others are primitive types
812          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
813          break;
814        }
815      }
816      break;
817    }
818    case clang::Type::Builtin: {
819      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
820      break;
821    }
822    case clang::Type::Pointer: {
823      ET = RSExportPointerType::Create(Context,
824                                       static_cast<const clang::PointerType*>(CTI),
825                                       TypeName);
826      // FIXME: free the name (allocated in RSExportType::GetTypeName)
827      delete [] TypeName.data();
828      break;
829    }
830    case clang::Type::ExtVector: {
831      ET = RSExportVectorType::Create(Context,
832                                      static_cast<const clang::ExtVectorType*>(CTI),
833                                      TypeName);
834      break;
835    }
836    case clang::Type::ConstantArray: {
837      ET = RSExportConstantArrayType::Create(
838              Context,
839              static_cast<const clang::ConstantArrayType*>(CTI));
840      break;
841    }
842    default: {
843      Context->ReportError("unknown type cannot be exported: '%0'")
844          << T->getTypeClassName();
845      break;
846    }
847  }
848
849  return ET;
850}
851
852RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
853  llvm::StringRef TypeName;
854  if (NormalizeType(T, TypeName, Context, nullptr)) {
855    return Create(Context, T, TypeName);
856  } else {
857    return nullptr;
858  }
859}
860
861RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
862                                           const clang::VarDecl *VD) {
863  return RSExportType::Create(Context, GetTypeOfDecl(VD));
864}
865
866size_t RSExportType::getStoreSize() const {
867  return getRSContext()->getDataLayout()->getTypeStoreSize(getLLVMType());
868}
869
870size_t RSExportType::getAllocSize() const {
871    return getRSContext()->getDataLayout()->getTypeAllocSize(getLLVMType());
872}
873
874RSExportType::RSExportType(RSContext *Context,
875                           ExportClass Class,
876                           const llvm::StringRef &Name)
877    : RSExportable(Context, RSExportable::EX_TYPE),
878      mClass(Class),
879      // Make a copy on Name since memory stored @Name is either allocated in
880      // ASTContext or allocated in GetTypeName which will be destroyed later.
881      mName(Name.data(), Name.size()),
882      mLLVMType(nullptr) {
883  // Don't cache the type whose name start with '<'. Those type failed to
884  // get their name since constructing their name in GetTypeName() requiring
885  // complicated work.
886  if (!IsDummyName(Name)) {
887    // TODO(zonr): Need to check whether the insertion is successful or not.
888    Context->insertExportType(llvm::StringRef(Name), this);
889  }
890
891}
892
893bool RSExportType::keep() {
894  if (!RSExportable::keep())
895    return false;
896  // Invalidate converted LLVM type.
897  mLLVMType = nullptr;
898  return true;
899}
900
901bool RSExportType::equals(const RSExportable *E) const {
902  CHECK_PARENT_EQUALITY(RSExportable, E);
903  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
904}
905
906RSExportType::~RSExportType() {
907}
908
909/************************** RSExportPrimitiveType **************************/
// Lazily constructed map from the type names listed in
// MatrixAndObjectDataTypes to their DataType. It is filled on first use by
// GetRSSpecificType(const llvm::StringRef &).
llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
RSExportPrimitiveType::RSSpecificTypeMap;
912
913bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
914  if ((T != nullptr) && (T->getTypeClass() == clang::Type::Builtin))
915    return true;
916  else
917    return false;
918}
919
920DataType
921RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
922  if (TypeName.empty())
923    return DataTypeUnknown;
924
925  if (RSSpecificTypeMap->empty()) {
926    for (int i = 0; i < MatrixAndObjectDataTypesCount; i++) {
927      (*RSSpecificTypeMap)[MatrixAndObjectDataTypes[i].name] =
928          MatrixAndObjectDataTypes[i].dataType;
929    }
930  }
931
932  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
933  if (I == RSSpecificTypeMap->end())
934    return DataTypeUnknown;
935  else
936    return I->getValue();
937}
938
939DataType RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
940  T = GetCanonicalType(T);
941  if ((T == nullptr) || (T->getTypeClass() != clang::Type::Record))
942    return DataTypeUnknown;
943
944  return GetRSSpecificType( RSExportType::GetTypeName(T) );
945}
946
947bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
948    if (DT < 0 || DT >= DataTypeMax) {
949        return false;
950    }
951    return gReflectionTypes[DT].category == MatrixDataType;
952}
953
954bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
955    if (DT < 0 || DT >= DataTypeMax) {
956        return false;
957    }
958    return gReflectionTypes[DT].category == ObjectDataType;
959}
960
961bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
962  bool RSObjectTypeSeen = false;
963  while (T && T->isArrayType()) {
964    T = T->getArrayElementTypeNoTypeQual();
965  }
966
967  const clang::RecordType *RT = T->getAsStructureType();
968  if (!RT) {
969    return false;
970  }
971
972  const clang::RecordDecl *RD = RT->getDecl();
973  if (RD) {
974    RD = RD->getDefinition();
975  }
976  if (!RD) {
977    return false;
978  }
979
980  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
981         FE = RD->field_end();
982       FI != FE;
983       FI++) {
984    // We just look through all field declarations to see if we find a
985    // declaration for an RS object type (or an array of one).
986    const clang::FieldDecl *FD = *FI;
987    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
988    while (FT && FT->isArrayType()) {
989      FT = FT->getArrayElementTypeNoTypeQual();
990    }
991
992    DataType DT = GetRSSpecificType(FT);
993    if (IsRSObjectType(DT)) {
994      // RS object types definitely need to be zero-initialized
995      RSObjectTypeSeen = true;
996    } else {
997      switch (DT) {
998        case DataTypeRSMatrix2x2:
999        case DataTypeRSMatrix3x3:
1000        case DataTypeRSMatrix4x4:
1001          // Matrix types should get zero-initialized as well
1002          RSObjectTypeSeen = true;
1003          break;
1004        default:
1005          // Ignore all other primitive types
1006          break;
1007      }
1008      while (FT && FT->isArrayType()) {
1009        FT = FT->getArrayElementTypeNoTypeQual();
1010      }
1011      if (FT->isStructureType()) {
1012        // Recursively handle structs of structs (even though these can't
1013        // be exported, it is possible for a user to have them internally).
1014        RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
1015      }
1016    }
1017  }
1018
1019  return RSObjectTypeSeen;
1020}
1021
1022size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
1023  int type = EPT->getType();
1024  slangAssert((type > DataTypeUnknown && type < DataTypeMax) &&
1025              "RSExportPrimitiveType::GetSizeInBits : unknown data type");
1026  // All RS object types are 256 bits in 64-bit RS.
1027  if (EPT->isRSObjectType() && EPT->getRSContext()->is64Bit()) {
1028    return 256;
1029  }
1030  return gReflectionTypes[type].size_in_bits;
1031}
1032
1033DataType
1034RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
1035  if (T == nullptr)
1036    return DataTypeUnknown;
1037
1038  switch (T->getTypeClass()) {
1039    case clang::Type::Builtin: {
1040      const clang::BuiltinType *BT =
1041              static_cast<const clang::BuiltinType*>(T->getCanonicalTypeInternal().getTypePtr());
1042      BuiltinInfo *info = FindBuiltinType(BT->getKind());
1043      if (info != nullptr) {
1044        return info->type;
1045      }
1046      // The size of type WChar depend on platform so we abandon the support
1047      // to them.
1048      Context->ReportError("built-in type cannot be exported: '%0'")
1049          << T->getTypeClassName();
1050      break;
1051    }
1052    case clang::Type::Record: {
1053      // must be RS object type
1054      return RSExportPrimitiveType::GetRSSpecificType(T);
1055    }
1056    default: {
1057      Context->ReportError("primitive type cannot be exported: '%0'")
1058          << T->getTypeClassName();
1059      break;
1060    }
1061  }
1062
1063  return DataTypeUnknown;
1064}
1065
1066RSExportPrimitiveType
1067*RSExportPrimitiveType::Create(RSContext *Context,
1068                               const clang::Type *T,
1069                               const llvm::StringRef &TypeName,
1070                               bool Normalized) {
1071  DataType DT = GetDataType(Context, T);
1072
1073  if ((DT == DataTypeUnknown) || TypeName.empty())
1074    return nullptr;
1075  else
1076    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
1077                                     DT, Normalized);
1078}
1079
1080RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
1081                                                     const clang::Type *T) {
1082  llvm::StringRef TypeName;
1083  if (RSExportType::NormalizeType(T, TypeName, Context, nullptr)
1084      && IsPrimitiveType(T)) {
1085    return Create(Context, T, TypeName);
1086  } else {
1087    return nullptr;
1088  }
1089}
1090
1091llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
1092  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1093
1094  if (isRSObjectType()) {
1095    // struct {
1096    //   int *p;
1097    // } __attribute__((packed, aligned(pointer_size)))
1098    //
1099    // which is
1100    //
1101    // <{ [1 x i32] }> in LLVM
1102    //
1103    std::vector<llvm::Type *> Elements;
1104    if (getRSContext()->is64Bit()) {
1105      // 64-bit path
1106      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt64Ty(C), 4));
1107      return llvm::StructType::get(C, Elements, true);
1108    } else {
1109      // 32-bit legacy path
1110      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
1111      return llvm::StructType::get(C, Elements, true);
1112    }
1113  }
1114
1115  switch (mType) {
1116    case DataTypeFloat16: {
1117      return llvm::Type::getHalfTy(C);
1118      break;
1119    }
1120    case DataTypeFloat32: {
1121      return llvm::Type::getFloatTy(C);
1122      break;
1123    }
1124    case DataTypeFloat64: {
1125      return llvm::Type::getDoubleTy(C);
1126      break;
1127    }
1128    case DataTypeBoolean: {
1129      return llvm::Type::getInt1Ty(C);
1130      break;
1131    }
1132    case DataTypeSigned8:
1133    case DataTypeUnsigned8: {
1134      return llvm::Type::getInt8Ty(C);
1135      break;
1136    }
1137    case DataTypeSigned16:
1138    case DataTypeUnsigned16:
1139    case DataTypeUnsigned565:
1140    case DataTypeUnsigned5551:
1141    case DataTypeUnsigned4444: {
1142      return llvm::Type::getInt16Ty(C);
1143      break;
1144    }
1145    case DataTypeSigned32:
1146    case DataTypeUnsigned32: {
1147      return llvm::Type::getInt32Ty(C);
1148      break;
1149    }
1150    case DataTypeSigned64:
1151    case DataTypeUnsigned64: {
1152      return llvm::Type::getInt64Ty(C);
1153      break;
1154    }
1155    default: {
1156      slangAssert(false && "Unknown data type");
1157    }
1158  }
1159
1160  return nullptr;
1161}
1162
1163bool RSExportPrimitiveType::equals(const RSExportable *E) const {
1164  CHECK_PARENT_EQUALITY(RSExportType, E);
1165  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
1166}
1167
1168RSReflectionType *RSExportPrimitiveType::getRSReflectionType(DataType DT) {
1169  if (DT > DataTypeUnknown && DT < DataTypeMax) {
1170    return &gReflectionTypes[DT];
1171  } else {
1172    return nullptr;
1173  }
1174}
1175
1176/**************************** RSExportPointerType ****************************/
1177
1178RSExportPointerType
1179*RSExportPointerType::Create(RSContext *Context,
1180                             const clang::PointerType *PT,
1181                             const llvm::StringRef &TypeName) {
1182  const clang::Type *PointeeType = GetPointeeType(PT);
1183  const RSExportType *PointeeET;
1184
1185  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
1186    PointeeET = RSExportType::Create(Context, PointeeType);
1187  } else {
1188    // Double or higher dimension of pointer, export as int*
1189    PointeeET = RSExportPrimitiveType::Create(Context,
1190                    Context->getASTContext().IntTy.getTypePtr());
1191  }
1192
1193  if (PointeeET == nullptr) {
1194    // Error diagnostic is emitted for corresponding pointee type
1195    return nullptr;
1196  }
1197
1198  return new RSExportPointerType(Context, TypeName, PointeeET);
1199}
1200
1201llvm::Type *RSExportPointerType::convertToLLVMType() const {
1202  llvm::Type *PointeeType = mPointeeType->getLLVMType();
1203  return llvm::PointerType::getUnqual(PointeeType);
1204}
1205
1206bool RSExportPointerType::keep() {
1207  if (!RSExportType::keep())
1208    return false;
1209  const_cast<RSExportType*>(mPointeeType)->keep();
1210  return true;
1211}
1212
1213bool RSExportPointerType::equals(const RSExportable *E) const {
1214  CHECK_PARENT_EQUALITY(RSExportType, E);
1215  return (static_cast<const RSExportPointerType*>(E)
1216              ->getPointeeType()->equals(getPointeeType()));
1217}
1218
1219/***************************** RSExportVectorType *****************************/
1220llvm::StringRef
1221RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
1222  const clang::Type *ElementType = GetExtVectorElementType(EVT);
1223  llvm::StringRef name;
1224
1225  if ((ElementType->getTypeClass() != clang::Type::Builtin))
1226    return name;
1227
1228  const clang::BuiltinType *BT =
1229          static_cast<const clang::BuiltinType*>(
1230              ElementType->getCanonicalTypeInternal().getTypePtr());
1231
1232  if ((EVT->getNumElements() < 1) ||
1233      (EVT->getNumElements() > 4))
1234    return name;
1235
1236  BuiltinInfo *info = FindBuiltinType(BT->getKind());
1237  if (info != nullptr) {
1238    int I = EVT->getNumElements() - 1;
1239    if (I < kMaxVectorSize) {
1240      name = info->cname[I];
1241    } else {
1242      slangAssert(false && "Max vector is 4");
1243    }
1244  }
1245  return name;
1246}
1247
1248RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
1249                                               const clang::ExtVectorType *EVT,
1250                                               const llvm::StringRef &TypeName,
1251                                               bool Normalized) {
1252  slangAssert(EVT != nullptr && EVT->getTypeClass() == clang::Type::ExtVector);
1253
1254  const clang::Type *ElementType = GetExtVectorElementType(EVT);
1255  DataType DT = RSExportPrimitiveType::GetDataType(Context, ElementType);
1256
1257  if (DT != DataTypeUnknown)
1258    return new RSExportVectorType(Context,
1259                                  TypeName,
1260                                  DT,
1261                                  Normalized,
1262                                  EVT->getNumElements());
1263  else
1264    return nullptr;
1265}
1266
1267llvm::Type *RSExportVectorType::convertToLLVMType() const {
1268  llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
1269  return llvm::VectorType::get(ElementType, getNumElement());
1270}
1271
1272bool RSExportVectorType::equals(const RSExportable *E) const {
1273  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
1274  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
1275              == getNumElement());
1276}
1277
1278/***************************** RSExportMatrixType *****************************/
1279RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
1280                                               const clang::RecordType *RT,
1281                                               const llvm::StringRef &TypeName,
1282                                               unsigned Dim) {
1283  slangAssert((RT != nullptr) && (RT->getTypeClass() == clang::Type::Record));
1284  slangAssert((Dim > 1) && "Invalid dimension of matrix");
1285
1286  // Check whether the struct rs_matrix is in our expected form (but assume it's
1287  // correct if we're not sure whether it's correct or not)
1288  const clang::RecordDecl* RD = RT->getDecl();
1289  RD = RD->getDefinition();
1290  if (RD != nullptr) {
1291    // Find definition, perform further examination
1292    if (RD->field_empty()) {
1293      Context->ReportError(
1294          RD->getLocation(),
1295          "invalid matrix struct: must have 1 field for saving values: '%0'")
1296          << RD->getName();
1297      return nullptr;
1298    }
1299
1300    clang::RecordDecl::field_iterator FIT = RD->field_begin();
1301    const clang::FieldDecl *FD = *FIT;
1302    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1303    if ((FT == nullptr) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
1304      Context->ReportError(RD->getLocation(),
1305                           "invalid matrix struct: first field should"
1306                           " be an array with constant size: '%0'")
1307          << RD->getName();
1308      return nullptr;
1309    }
1310    const clang::ConstantArrayType *CAT =
1311      static_cast<const clang::ConstantArrayType *>(FT);
1312    const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1313    if ((ElementType == nullptr) ||
1314        (ElementType->getTypeClass() != clang::Type::Builtin) ||
1315        (static_cast<const clang::BuiltinType *>(ElementType)->getKind() !=
1316         clang::BuiltinType::Float)) {
1317      Context->ReportError(RD->getLocation(),
1318                           "invalid matrix struct: first field "
1319                           "should be a float array: '%0'")
1320          << RD->getName();
1321      return nullptr;
1322    }
1323
1324    if (CAT->getSize() != Dim * Dim) {
1325      Context->ReportError(RD->getLocation(),
1326                           "invalid matrix struct: first field "
1327                           "should be an array with size %0: '%1'")
1328          << (Dim * Dim) << (RD->getName());
1329      return nullptr;
1330    }
1331
1332    FIT++;
1333    if (FIT != RD->field_end()) {
1334      Context->ReportError(RD->getLocation(),
1335                           "invalid matrix struct: must have "
1336                           "exactly 1 field: '%0'")
1337          << RD->getName();
1338      return nullptr;
1339    }
1340  }
1341
1342  return new RSExportMatrixType(Context, TypeName, Dim);
1343}
1344
1345llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1346  // Construct LLVM type:
1347  // struct {
1348  //  float X[mDim * mDim];
1349  // }
1350
1351  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1352  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1353                                            mDim * mDim);
1354  return llvm::StructType::get(C, X, false);
1355}
1356
1357bool RSExportMatrixType::equals(const RSExportable *E) const {
1358  CHECK_PARENT_EQUALITY(RSExportType, E);
1359  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1360}
1361
1362/************************* RSExportConstantArrayType *************************/
1363RSExportConstantArrayType
1364*RSExportConstantArrayType::Create(RSContext *Context,
1365                                   const clang::ConstantArrayType *CAT) {
1366  slangAssert(CAT != nullptr && CAT->getTypeClass() == clang::Type::ConstantArray);
1367
1368  slangAssert((CAT->getSize().getActiveBits() < 32) && "array too large");
1369
1370  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1371  slangAssert((Size > 0) && "Constant array should have size greater than 0");
1372
1373  const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1374  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
1375
1376  if (ElementET == nullptr) {
1377    return nullptr;
1378  }
1379
1380  return new RSExportConstantArrayType(Context,
1381                                       ElementET,
1382                                       Size);
1383}
1384
1385llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1386  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
1387}
1388
1389bool RSExportConstantArrayType::keep() {
1390  if (!RSExportType::keep())
1391    return false;
1392  const_cast<RSExportType*>(mElementType)->keep();
1393  return true;
1394}
1395
1396bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1397  CHECK_PARENT_EQUALITY(RSExportType, E);
1398  const RSExportConstantArrayType *RHS =
1399      static_cast<const RSExportConstantArrayType*>(E);
1400  return ((getSize() == RHS->getSize()) &&
1401          (getElementType()->equals(RHS->getElementType())));
1402}
1403
1404/**************************** RSExportRecordType ****************************/
1405RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1406                                               const clang::RecordType *RT,
1407                                               const llvm::StringRef &TypeName,
1408                                               bool mIsArtificial) {
1409  slangAssert(RT != nullptr && RT->getTypeClass() == clang::Type::Record);
1410
1411  const clang::RecordDecl *RD = RT->getDecl();
1412  slangAssert(RD->isStruct());
1413
1414  RD = RD->getDefinition();
1415  if (RD == nullptr) {
1416    slangAssert(false && "struct is not defined in this module");
1417    return nullptr;
1418  }
1419
1420  // Struct layout construct by clang. We rely on this for obtaining the
1421  // alloc size of a struct and offset of every field in that struct.
1422  const clang::ASTRecordLayout *RL =
1423      &Context->getASTContext().getASTRecordLayout(RD);
1424  slangAssert((RL != nullptr) &&
1425      "Failed to retrieve the struct layout from Clang.");
1426
1427  RSExportRecordType *ERT =
1428      new RSExportRecordType(Context,
1429                             TypeName,
1430                             RD->hasAttr<clang::PackedAttr>(),
1431                             mIsArtificial,
1432                             RL->getDataSize().getQuantity(),
1433                             RL->getSize().getQuantity());
1434  unsigned int Index = 0;
1435
1436  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1437           FE = RD->field_end();
1438       FI != FE;
1439       FI++, Index++) {
1440
1441    // FIXME: All fields should be primitive type
1442    slangAssert(FI->getKind() == clang::Decl::Field);
1443    clang::FieldDecl *FD = *FI;
1444
1445    if (FD->isBitField()) {
1446      return nullptr;
1447    }
1448
1449    // Type
1450    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1451
1452    if (ET != nullptr) {
1453      ERT->mFields.push_back(
1454          new Field(ET, FD->getName(), ERT,
1455                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1456    } else {
1457      Context->ReportError(RD->getLocation(),
1458                           "field type cannot be exported: '%0.%1'")
1459          << RD->getName() << FD->getName();
1460      return nullptr;
1461    }
1462  }
1463
1464  return ERT;
1465}
1466
1467llvm::Type *RSExportRecordType::convertToLLVMType() const {
1468  // Create an opaque type since struct may reference itself recursively.
1469
1470  // TODO(sliao): LLVM took out the OpaqueType. Any other to migrate to?
1471  std::vector<llvm::Type*> FieldTypes;
1472
1473  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1474       FI != FE;
1475       FI++) {
1476    const Field *F = *FI;
1477    const RSExportType *FET = F->getType();
1478
1479    FieldTypes.push_back(FET->getLLVMType());
1480  }
1481
1482  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1483                                               FieldTypes,
1484                                               mIsPacked);
1485  if (ST != nullptr) {
1486    return ST;
1487  } else {
1488    return nullptr;
1489  }
1490}
1491
1492bool RSExportRecordType::keep() {
1493  if (!RSExportType::keep())
1494    return false;
1495  for (std::list<const Field*>::iterator I = mFields.begin(),
1496          E = mFields.end();
1497       I != E;
1498       I++) {
1499    const_cast<RSExportType*>((*I)->getType())->keep();
1500  }
1501  return true;
1502}
1503
1504bool RSExportRecordType::equals(const RSExportable *E) const {
1505  CHECK_PARENT_EQUALITY(RSExportType, E);
1506
1507  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1508
1509  if (ERT->getFields().size() != getFields().size())
1510    return false;
1511
1512  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1513
1514  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1515    if (!(*AI)->getType()->equals((*BI)->getType()))
1516      return false;
1517    AI++;
1518    BI++;
1519  }
1520
1521  return true;
1522}
1523
1524void RSExportType::convertToRTD(RSReflectionTypeData *rtd) const {
1525    memset(rtd, 0, sizeof(*rtd));
1526    rtd->vecSize = 1;
1527
1528    switch(getClass()) {
1529    case RSExportType::ExportClassPrimitive: {
1530            const RSExportPrimitiveType *EPT = static_cast<const RSExportPrimitiveType*>(this);
1531            rtd->type = RSExportPrimitiveType::getRSReflectionType(EPT);
1532            return;
1533        }
1534    case RSExportType::ExportClassPointer: {
1535            const RSExportPointerType *EPT = static_cast<const RSExportPointerType*>(this);
1536            const RSExportType *PointeeType = EPT->getPointeeType();
1537            PointeeType->convertToRTD(rtd);
1538            rtd->isPointer = true;
1539            return;
1540        }
1541    case RSExportType::ExportClassVector: {
1542            const RSExportVectorType *EVT = static_cast<const RSExportVectorType*>(this);
1543            rtd->type = EVT->getRSReflectionType(EVT);
1544            rtd->vecSize = EVT->getNumElement();
1545            return;
1546        }
1547    case RSExportType::ExportClassMatrix: {
1548            const RSExportMatrixType *EMT = static_cast<const RSExportMatrixType*>(this);
1549            unsigned Dim = EMT->getDim();
1550            slangAssert((Dim >= 2) && (Dim <= 4));
1551            rtd->type = &gReflectionTypes[15 + Dim-2];
1552            return;
1553        }
1554    case RSExportType::ExportClassConstantArray: {
1555            const RSExportConstantArrayType* CAT =
1556              static_cast<const RSExportConstantArrayType*>(this);
1557            CAT->getElementType()->convertToRTD(rtd);
1558            rtd->arraySize = CAT->getSize();
1559            return;
1560        }
1561    case RSExportType::ExportClassRecord: {
1562            slangAssert(!"RSExportType::ExportClassRecord not implemented");
1563            return;// RS_TYPE_CLASS_NAME_PREFIX + ET->getName() + ".Item";
1564        }
1565    default: {
1566            slangAssert(false && "Unknown class of type");
1567        }
1568    }
1569}
1570
1571
1572}  // namespace slang
1573