slang_rs_export_type.cpp revision 13fad85b3c99a37c17d8acfec72f46b8ee64e912
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
#include <list>
#include <set>
#include <string>
#include <vector>
21
22#include "clang/AST/ASTContext.h"
23#include "clang/AST/Attr.h"
24#include "clang/AST/RecordLayout.h"
25
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/DerivedTypes.h"
29#include "llvm/IR/Type.h"
30
31#include "slang_assert.h"
32#include "slang_rs_context.h"
33#include "slang_rs_export_element.h"
34#include "slang_version.h"
35
// Makes the enclosing function return false when ParentClass::equals(E)
// rejects E. The expansion is a bare `if` statement (there is no do/while
// wrapper), so use it only as a standalone statement.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  if (!ParentClass::equals(E))                \
    return false;
39
40namespace slang {
41
42namespace {
43
44// For the data types we support:
45//  Category      - data type category
46//  RsType        - element name in RenderScript
47//  RsShortType   - short element name in RenderScript
48//  SizeInBits    - size in bits
49//  CName         - reflected C name
50//  JavaName      - reflected Java name
51//  JavaArrayElementName - reflected name in Java arrays
52//  CVecName      - prefix for C vector types
53//  JavaVecName   - prefix for Java vector type
54//  JavaPromotion - unsigned type undergoing Java promotion
55//
56// IMPORTANT: The data types in this table should be at the same index as
57// specified by the corresponding DataType enum.
58//
59// TODO: Pull this information out into a separate file.
static RSReflectionType gReflectionTypes[] = {
// `_` stands for nullptr in columns that do not apply to a given type; it is
// undefined again right after the last row.
#define _ nullptr
{PrimitiveDataType,         "FLOAT_16",     "F16", 16,     "half",   "short",        _,   "Half",   "Half", false},
{PrimitiveDataType,         "FLOAT_32",     "F32", 32,    "float",   "float",  "float",  "Float",  "Float", false},
{PrimitiveDataType,         "FLOAT_64",     "F64", 64,   "double",  "double", "double", "Double", "Double", false},
{PrimitiveDataType,         "SIGNED_8",      "I8",  8,   "int8_t",    "byte",   "byte",   "Byte",   "Byte", false},
{PrimitiveDataType,        "SIGNED_16",     "I16", 16,  "int16_t",   "short",  "short",  "Short",  "Short", false},
{PrimitiveDataType,        "SIGNED_32",     "I32", 32,  "int32_t",     "int",    "int",    "Int",    "Int", false},
{PrimitiveDataType,        "SIGNED_64",     "I64", 64,  "int64_t",    "long",   "long",   "Long",   "Long", false},
{PrimitiveDataType,       "UNSIGNED_8",      "U8",  8,  "uint8_t",   "short",   "byte",  "UByte",  "Short",  true},
{PrimitiveDataType,      "UNSIGNED_16",     "U16", 16, "uint16_t",     "int",  "short", "UShort",    "Int",  true},
{PrimitiveDataType,      "UNSIGNED_32",     "U32", 32, "uint32_t",    "long",    "int",   "UInt",   "Long",  true},
{PrimitiveDataType,      "UNSIGNED_64",     "U64", 64, "uint64_t",    "long",   "long",  "ULong",   "Long", false},
{PrimitiveDataType,          "BOOLEAN", "BOOLEAN",  8,     "bool", "boolean",   "byte",        _,        _, false},
{PrimitiveDataType,   "UNSIGNED_5_6_5",         _, 16,          _,         _,        _,        _,        _, false},
{PrimitiveDataType, "UNSIGNED_5_5_5_1",         _, 16,          _,         _,        _,        _,        _, false},
{PrimitiveDataType, "UNSIGNED_4_4_4_4",         _, 16,          _,         _,        _,        _,        _, false},

// Matrix sizes are written as <number of elements> * 32 bits.
{MatrixDataType, "MATRIX_2X2", _,  4*32, "rs_matrix2x2", "Matrix2f", _, _, _, false},
{MatrixDataType, "MATRIX_3X3", _,  9*32, "rs_matrix3x3", "Matrix3f", _, _, _, false},
{MatrixDataType, "MATRIX_4X4", _, 16*32, "rs_matrix4x4", "Matrix4f", _, _, _, false},

// RS object types are 32 bits in 32-bit RS, but 256 bits in 64-bit RS.
// This is handled specially by the GetSizeInBits() method.
{ObjectDataType,          "RS_ELEMENT",          "ELEMENT", 32,         "Element",         "Element", _, _, _, false},
{ObjectDataType,             "RS_TYPE",             "TYPE", 32,            "Type",            "Type", _, _, _, false},
{ObjectDataType,       "RS_ALLOCATION",       "ALLOCATION", 32,      "Allocation",      "Allocation", _, _, _, false},
{ObjectDataType,          "RS_SAMPLER",          "SAMPLER", 32,         "Sampler",         "Sampler", _, _, _, false},
{ObjectDataType,           "RS_SCRIPT",           "SCRIPT", 32,          "Script",          "Script", _, _, _, false},
{ObjectDataType,             "RS_MESH",             "MESH", 32,            "Mesh",            "Mesh", _, _, _, false},
{ObjectDataType,             "RS_PATH",             "PATH", 32,            "Path",            "Path", _, _, _, false},
{ObjectDataType, "RS_PROGRAM_FRAGMENT", "PROGRAM_FRAGMENT", 32, "ProgramFragment", "ProgramFragment", _, _, _, false},
{ObjectDataType,   "RS_PROGRAM_VERTEX",   "PROGRAM_VERTEX", 32,   "ProgramVertex",   "ProgramVertex", _, _, _, false},
{ObjectDataType,   "RS_PROGRAM_RASTER",   "PROGRAM_RASTER", 32,   "ProgramRaster",   "ProgramRaster", _, _, _, false},
{ObjectDataType,    "RS_PROGRAM_STORE",    "PROGRAM_STORE", 32,    "ProgramStore",    "ProgramStore", _, _, _, false},
{ObjectDataType,             "RS_FONT",             "FONT", 32,            "Font",            "Font", _, _, _, false},
#undef _
};
98
// Number of C names stored per builtin type: the scalar name followed by the
// names of its 2-, 3- and 4-wide vector forms.
const int kMaxVectorSize = 4;

// Associates a clang builtin type kind with the RenderScript DataType it maps
// to and with the C names of its scalar and vector forms.
struct BuiltinInfo {
  clang::BuiltinType::Kind builtinTypeKind;
  DataType type;
  /* TODO If we return std::string instead of llvm::StringRef, we could build
   * the name instead of duplicating the entries.
   */
  // cname[0] is the scalar name; cname[i] names the (i + 1)-wide vector.
  const char *cname[kMaxVectorSize];
};
109
110
// Maps every supported clang builtin kind to its RenderScript data type and
// C names. Several clang kinds share one RenderScript type (for example,
// Char_U and UChar both map to DataTypeUnsigned8, and Long and LongLong both
// map to DataTypeSigned64). FindBuiltinType() searches this table.
BuiltinInfo BuiltinInfoTable[] = {
    {clang::BuiltinType::Bool, DataTypeBoolean,
     {"bool", "bool2", "bool3", "bool4"}},
    {clang::BuiltinType::Char_U, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::UChar, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::Char16, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Char32, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::UShort, DataTypeUnsigned16,
     {"ushort", "ushort2", "ushort3", "ushort4"}},
    {clang::BuiltinType::UInt, DataTypeUnsigned32,
     {"uint", "uint2", "uint3", "uint4"}},
    {clang::BuiltinType::ULong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},
    {clang::BuiltinType::ULongLong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},

    {clang::BuiltinType::Char_S, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::SChar, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::Short, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Int, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::Long, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::LongLong, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::Half, DataTypeFloat16,
     {"half", "half2", "half3", "half4"}},
    {clang::BuiltinType::Float, DataTypeFloat32,
     {"float", "float2", "float3", "float4"}},
    {clang::BuiltinType::Double, DataTypeFloat64,
     {"double", "double2", "double3", "double4"}},
};
// Number of entries in BuiltinInfoTable.
const int BuiltinInfoTableCount = sizeof(BuiltinInfoTable) / sizeof(BuiltinInfoTable[0]);
151
// Pairs a type name with the DataType value it denotes.
struct NameAndPrimitiveType {
  const char *name;
  DataType dataType;
};

// Names of the RenderScript matrix types (first three entries) and object
// types (remaining entries), each paired with its DataType value.
static NameAndPrimitiveType MatrixAndObjectDataTypes[] = {
    {"rs_matrix2x2", DataTypeRSMatrix2x2},
    {"rs_matrix3x3", DataTypeRSMatrix3x3},
    {"rs_matrix4x4", DataTypeRSMatrix4x4},
    {"rs_element", DataTypeRSElement},
    {"rs_type", DataTypeRSType},
    {"rs_allocation", DataTypeRSAllocation},
    {"rs_sampler", DataTypeRSSampler},
    {"rs_script", DataTypeRSScript},
    {"rs_mesh", DataTypeRSMesh},
    {"rs_path", DataTypeRSPath},
    {"rs_program_fragment", DataTypeRSProgramFragment},
    {"rs_program_vertex", DataTypeRSProgramVertex},
    {"rs_program_raster", DataTypeRSProgramRaster},
    {"rs_program_store", DataTypeRSProgramStore},
    {"rs_font", DataTypeRSFont},
};

// Number of entries in MatrixAndObjectDataTypes.
const int MatrixAndObjectDataTypesCount =
    sizeof(MatrixAndObjectDataTypes) / sizeof(MatrixAndObjectDataTypes[0]);
177
// Forward declaration: ConstantArrayTypeExportableHelper() below calls
// TypeExportableHelper() on the array element type before the latter is
// defined.
static const clang::Type *TypeExportableHelper(
    const clang::Type *T,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    slang::RSContext *Context,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord,
    ExportKind EK);
185
186template <unsigned N>
187static void ReportTypeError(slang::RSContext *Context,
188                            const clang::NamedDecl *ND,
189                            const clang::RecordDecl *TopLevelRecord,
190                            const char (&Message)[N],
191                            unsigned int TargetAPI = 0) {
192  // Attempt to use the type declaration first (if we have one).
193  // Fall back to the variable definition, if we are looking at something
194  // like an array declaration that can't be exported.
195  if (TopLevelRecord) {
196    Context->ReportError(TopLevelRecord->getLocation(), Message)
197        << TopLevelRecord->getName() << TargetAPI;
198  } else if (ND) {
199    Context->ReportError(ND->getLocation(), Message) << ND->getName()
200                                                     << TargetAPI;
201  } else {
202    slangAssert(false && "Variables should be validated before exporting");
203  }
204}
205
206static const clang::Type *ConstantArrayTypeExportableHelper(
207    const clang::ConstantArrayType *CAT,
208    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
209    slang::RSContext *Context,
210    const clang::VarDecl *VD,
211    const clang::RecordDecl *TopLevelRecord,
212    ExportKind EK) {
213  // Check element type
214  const clang::Type *ElementType = GetConstantArrayElementType(CAT);
215  if (ElementType->isArrayType()) {
216    ReportTypeError(Context, VD, TopLevelRecord,
217                    "multidimensional arrays cannot be exported: '%0'");
218    return nullptr;
219  } else if (ElementType->isExtVectorType()) {
220    const clang::ExtVectorType *EVT =
221        static_cast<const clang::ExtVectorType*>(ElementType);
222    unsigned numElements = EVT->getNumElements();
223
224    const clang::Type *BaseElementType = GetExtVectorElementType(EVT);
225    if (!RSExportPrimitiveType::IsPrimitiveType(BaseElementType)) {
226      ReportTypeError(Context, VD, TopLevelRecord,
227        "vectors of non-primitive types cannot be exported: '%0'");
228      return nullptr;
229    }
230
231    if (numElements == 3 && CAT->getSize() != 1) {
232      ReportTypeError(Context, VD, TopLevelRecord,
233        "arrays of width 3 vector types cannot be exported: '%0'");
234      return nullptr;
235    }
236  }
237
238  if (TypeExportableHelper(ElementType, SPS, Context, VD,
239                           TopLevelRecord, EK) == nullptr) {
240    return nullptr;
241  } else {
242    return CAT;
243  }
244}
245
246BuiltinInfo *FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind) {
247  for (int i = 0; i < BuiltinInfoTableCount; i++) {
248    if (builtinTypeKind == BuiltinInfoTable[i].builtinTypeKind) {
249      return &BuiltinInfoTable[i];
250    }
251  }
252  return nullptr;
253}
254
255static const clang::Type *TypeExportableHelper(
256    clang::Type const *T,
257    llvm::SmallPtrSet<clang::Type const *, 8> &SPS,
258    slang::RSContext *Context,
259    clang::VarDecl const *VD,
260    clang::RecordDecl const *TopLevelRecord,
261    ExportKind EK) {
262  // Normalize first
263  if ((T = GetCanonicalType(T)) == nullptr)
264    return nullptr;
265
266  if (SPS.count(T))
267    return T;
268
269  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
270
271  switch (T->getTypeClass()) {
272    case clang::Type::Builtin: {
273      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
274      return FindBuiltinType(BT->getKind()) == nullptr ? nullptr : T;
275    }
276    case clang::Type::Record: {
277      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
278        return T;  // RS object type, no further checks are needed
279      }
280
281      // Check internal struct
282      if (T->isUnionType()) {
283        ReportTypeError(Context, VD, T->getAsUnionType()->getDecl(),
284                        "unions cannot be exported: '%0'");
285        return nullptr;
286      } else if (!T->isStructureType()) {
287        slangAssert(false && "Unknown type cannot be exported");
288        return nullptr;
289      }
290
291      clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
292      if (RD != nullptr) {
293        RD = RD->getDefinition();
294        if (RD == nullptr) {
295          ReportTypeError(Context, nullptr, T->getAsStructureType()->getDecl(),
296                          "struct is not defined in this module");
297          return nullptr;
298        }
299      }
300
301      if (!TopLevelRecord) {
302        TopLevelRecord = RD;
303      }
304      if (RD->getName().empty()) {
305        ReportTypeError(Context, nullptr, RD,
306                        "anonymous structures cannot be exported");
307        return nullptr;
308      }
309
310      // Fast check
311      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
312        return nullptr;
313
314      // Insert myself into checking set
315      SPS.insert(T);
316
317      // Check all element
318      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
319               FE = RD->field_end();
320           FI != FE;
321           FI++) {
322        const clang::FieldDecl *FD = *FI;
323        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
324        FT = GetCanonicalType(FT);
325
326        if (!TypeExportableHelper(FT, SPS, Context, VD, TopLevelRecord,
327                                  EK)) {
328          return nullptr;
329        }
330
331        // We don't support bit fields yet
332        //
333        // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
334        if (FD->isBitField()) {
335          Context->ReportError(
336              FD->getLocation(),
337              "bit fields are not able to be exported: '%0.%1'")
338              << RD->getName() << FD->getName();
339          return nullptr;
340        }
341      }
342
343      return T;
344    }
345    case clang::Type::Pointer: {
346      if (TopLevelRecord) {
347        ReportTypeError(Context, VD, TopLevelRecord,
348            "structures containing pointers cannot be used as the type of "
349            "an exported global variable or the parameter to an exported "
350            "function: '%0'");
351        return nullptr;
352      }
353
354      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
355      const clang::Type *PointeeType = GetPointeeType(PT);
356
357      if (PointeeType->getTypeClass() == clang::Type::Pointer) {
358        ReportTypeError(Context, VD, TopLevelRecord,
359            "multiple levels of pointers cannot be exported: '%0'");
360        return nullptr;
361      }
362
363      // Void pointers are forbidden for export, although we must accept
364      // void pointers that come in as arguments to a legacy kernel.
365      if (PointeeType->isVoidType() && EK != LegacyKernelArgument) {
366        ReportTypeError(Context, VD, TopLevelRecord,
367            "void pointers cannot be exported: '%0'");
368        return nullptr;
369      }
370
371      // We don't support pointer with array-type pointee or unsupported pointee
372      // type
373      if (PointeeType->isArrayType() ||
374          (TypeExportableHelper(PointeeType, SPS, Context, VD,
375                                TopLevelRecord, EK) == nullptr))
376        return nullptr;
377      else
378        return T;
379    }
380    case clang::Type::ExtVector: {
381      const clang::ExtVectorType *EVT =
382              static_cast<const clang::ExtVectorType*>(CTI);
383      // Only vector with size 2, 3 and 4 are supported.
384      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
385        return nullptr;
386
387      // Check base element type
388      const clang::Type *ElementType = GetExtVectorElementType(EVT);
389
390      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
391          (TypeExportableHelper(ElementType, SPS, Context, VD,
392                                TopLevelRecord, EK) == nullptr))
393        return nullptr;
394      else
395        return T;
396    }
397    case clang::Type::ConstantArray: {
398      const clang::ConstantArrayType *CAT =
399              static_cast<const clang::ConstantArrayType*>(CTI);
400
401      return ConstantArrayTypeExportableHelper(CAT, SPS, Context, VD,
402                                               TopLevelRecord, EK);
403    }
404    case clang::Type::Enum: {
405      // FIXME: We currently convert enums to integers, rather than reflecting
406      // a more complete (and nicer type-safe Java version).
407      return Context->getASTContext().IntTy.getTypePtr();
408    }
409    default: {
410      slangAssert(false && "Unknown type cannot be validated");
411      return nullptr;
412    }
413  }
414}
415
416// Return the type that can be used to create RSExportType, will always return
417// the canonical type.
418//
419// If the Type T is not exportable, this function returns nullptr. DiagEngine is
420// used to generate proper Clang diagnostic messages when a non-exportable type
421// is detected. TopLevelRecord is used to capture the highest struct (in the
422// case of a nested hierarchy) for detecting other types that cannot be exported
423// (mostly pointers within a struct).
424static const clang::Type *TypeExportable(const clang::Type *T,
425                                         slang::RSContext *Context,
426                                         const clang::VarDecl *VD,
427                                         ExportKind EK) {
428  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
429      llvm::SmallPtrSet<const clang::Type*, 8>();
430
431  return TypeExportableHelper(T, SPS, Context, VD, nullptr, EK);
432}
433
434static bool ValidateRSObjectInVarDecl(slang::RSContext *Context,
435                                      clang::VarDecl *VD, bool InCompositeType,
436                                      unsigned int TargetAPI) {
437  if (TargetAPI < SLANG_JB_TARGET_API) {
438    // Only if we are already in a composite type (like an array or structure).
439    if (InCompositeType) {
440      // Only if we are actually exported (i.e. non-static).
441      if (VD->hasLinkage() &&
442          (VD->getFormalLinkage() == clang::ExternalLinkage)) {
443        // Only if we are not a pointer to an object.
444        const clang::Type *T = GetCanonicalType(VD->getType().getTypePtr());
445        if (T->getTypeClass() != clang::Type::Pointer) {
446          ReportTypeError(Context, VD, nullptr,
447                          "arrays/structures containing RS object types "
448                          "cannot be exported in target API < %1: '%0'",
449                          SLANG_JB_TARGET_API);
450          return false;
451        }
452      }
453    }
454  }
455
456  return true;
457}
458
459// Helper function for ValidateType(). We do a recursive descent on the
460// type hierarchy to ensure that we can properly export/handle the
461// declaration.
462// \return true if the variable declaration is valid,
463//         false if it is invalid (along with proper diagnostics).
464//
465// C - ASTContext (for diagnostics + builtin types).
466// T - sub-type that we are validating.
467// ND - (optional) top-level named declaration that we are validating.
468// SPS - set of types we have already seen/validated.
469// InCompositeType - true if we are within an outer composite type.
470// UnionDecl - set if we are in a sub-type of a union.
471// TargetAPI - target SDK API level.
472// IsFilterscript - whether or not we are compiling for Filterscript
473// IsExtern - is this type externally visible (i.e. extern global or parameter
474//                                             to an extern function)
475static bool ValidateTypeHelper(
476    slang::RSContext *Context,
477    clang::ASTContext &C,
478    const clang::Type *&T,
479    clang::NamedDecl *ND,
480    clang::SourceLocation Loc,
481    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
482    bool InCompositeType,
483    clang::RecordDecl *UnionDecl,
484    unsigned int TargetAPI,
485    bool IsFilterscript,
486    bool IsExtern) {
487  if ((T = GetCanonicalType(T)) == nullptr)
488    return true;
489
490  if (SPS.count(T))
491    return true;
492
493  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
494
495  switch (T->getTypeClass()) {
496    case clang::Type::Record: {
497      if (RSExportPrimitiveType::IsRSObjectType(T)) {
498        clang::VarDecl *VD = (ND ? llvm::dyn_cast<clang::VarDecl>(ND) : nullptr);
499        if (VD && !ValidateRSObjectInVarDecl(Context, VD, InCompositeType,
500                                             TargetAPI)) {
501          return false;
502        }
503      }
504
505      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
506        if (!UnionDecl) {
507          return true;
508        } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
509          ReportTypeError(Context, nullptr, UnionDecl,
510              "unions containing RS object types are not allowed");
511          return false;
512        }
513      }
514
515      clang::RecordDecl *RD = nullptr;
516
517      // Check internal struct
518      if (T->isUnionType()) {
519        RD = T->getAsUnionType()->getDecl();
520        UnionDecl = RD;
521      } else if (T->isStructureType()) {
522        RD = T->getAsStructureType()->getDecl();
523      } else {
524        slangAssert(false && "Unknown type cannot be exported");
525        return false;
526      }
527
528      if (RD != nullptr) {
529        RD = RD->getDefinition();
530        if (RD == nullptr) {
531          // FIXME
532          return true;
533        }
534      }
535
536      // Fast check
537      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
538        return false;
539
540      // Insert myself into checking set
541      SPS.insert(T);
542
543      // Check all elements
544      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
545               FE = RD->field_end();
546           FI != FE;
547           FI++) {
548        const clang::FieldDecl *FD = *FI;
549        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
550        FT = GetCanonicalType(FT);
551
552        if (!ValidateTypeHelper(Context, C, FT, ND, Loc, SPS, true, UnionDecl,
553                                TargetAPI, IsFilterscript, IsExtern)) {
554          return false;
555        }
556      }
557
558      return true;
559    }
560
561    case clang::Type::Builtin: {
562      if (IsFilterscript) {
563        clang::QualType QT = T->getCanonicalTypeInternal();
564        if (QT == C.DoubleTy ||
565            QT == C.LongDoubleTy ||
566            QT == C.LongTy ||
567            QT == C.LongLongTy) {
568          if (ND) {
569            Context->ReportError(
570                Loc,
571                "Builtin types > 32 bits in size are forbidden in "
572                "Filterscript: '%0'")
573                << ND->getName();
574          } else {
575            Context->ReportError(
576                Loc,
577                "Builtin types > 32 bits in size are forbidden in "
578                "Filterscript");
579          }
580          return false;
581        }
582      }
583      break;
584    }
585
586    case clang::Type::Pointer: {
587      if (IsFilterscript) {
588        if (ND) {
589          Context->ReportError(Loc,
590                               "Pointers are forbidden in Filterscript: '%0'")
591              << ND->getName();
592          return false;
593        } else {
594          // TODO(srhines): Find a better way to handle expressions (i.e. no
595          // NamedDecl) involving pointers in FS that should be allowed.
596          // An example would be calls to library functions like
597          // rsMatrixMultiply() that take rs_matrixNxN * types.
598        }
599      }
600
601      // Forbid pointers in structures that are externally visible.
602      if (InCompositeType && IsExtern) {
603        if (ND) {
604          Context->ReportError(Loc,
605              "structures containing pointers cannot be used as the type of "
606              "an exported global variable or the parameter to an exported "
607              "function: '%0'")
608            << ND->getName();
609        } else {
610          Context->ReportError(Loc,
611              "structures containing pointers cannot be used as the type of "
612              "an exported global variable or the parameter to an exported "
613              "function");
614        }
615        return false;
616      }
617
618      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
619      const clang::Type *PointeeType = GetPointeeType(PT);
620
621      return ValidateTypeHelper(Context, C, PointeeType, ND, Loc, SPS,
622                                InCompositeType, UnionDecl, TargetAPI,
623                                IsFilterscript, IsExtern);
624    }
625
626    case clang::Type::ExtVector: {
627      const clang::ExtVectorType *EVT =
628              static_cast<const clang::ExtVectorType*>(CTI);
629      const clang::Type *ElementType = GetExtVectorElementType(EVT);
630      if (TargetAPI < SLANG_ICS_TARGET_API &&
631          InCompositeType &&
632          EVT->getNumElements() == 3 &&
633          ND &&
634          ND->getFormalLinkage() == clang::ExternalLinkage) {
635        ReportTypeError(Context, ND, nullptr,
636                        "structs containing vectors of dimension 3 cannot "
637                        "be exported at this API level: '%0'");
638        return false;
639      }
640      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
641                                UnionDecl, TargetAPI, IsFilterscript, IsExtern);
642    }
643
644    case clang::Type::ConstantArray: {
645      const clang::ConstantArrayType *CAT = static_cast<const clang::ConstantArrayType*>(CTI);
646      const clang::Type *ElementType = GetConstantArrayElementType(CAT);
647      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
648                                UnionDecl, TargetAPI, IsFilterscript, IsExtern);
649    }
650
651    default: {
652      break;
653    }
654  }
655
656  return true;
657}
658
659}  // namespace
660
// Builds a placeholder name of the form "<type>" or, when name is not empty,
// "<type:name>".
std::string CreateDummyName(const char *type, const std::string &name) {
  std::string Dummy("<");
  Dummy += type;
  if (!name.empty()) {
    Dummy += ':';
    Dummy += name;
  }
  Dummy += '>';
  return Dummy;
}
670
671/****************************** RSExportType ******************************/
672bool RSExportType::NormalizeType(const clang::Type *&T,
673                                 llvm::StringRef &TypeName,
674                                 RSContext *Context,
675                                 const clang::VarDecl *VD,
676                                 ExportKind EK) {
677  if ((T = TypeExportable(T, Context, VD, EK)) == nullptr) {
678    return false;
679  }
680  // Get type name
681  TypeName = RSExportType::GetTypeName(T);
682  if (Context && TypeName.empty()) {
683    if (VD) {
684      Context->ReportError(VD->getLocation(),
685                           "anonymous types cannot be exported");
686    } else {
687      Context->ReportError("anonymous types cannot be exported");
688    }
689    return false;
690  }
691
692  return true;
693}
694
695bool RSExportType::ValidateType(slang::RSContext *Context, clang::ASTContext &C,
696                                clang::QualType QT, clang::NamedDecl *ND,
697                                clang::SourceLocation Loc,
698                                unsigned int TargetAPI, bool IsFilterscript,
699                                bool IsExtern) {
700  const clang::Type *T = QT.getTypePtr();
701  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
702      llvm::SmallPtrSet<const clang::Type*, 8>();
703
704  // If this is an externally visible variable declaration, we check if the
705  // type is able to be exported first.
706  if (auto VD = llvm::dyn_cast_or_null<clang::VarDecl>(ND)) {
707    if (VD->getFormalLinkage() == clang::ExternalLinkage) {
708      if (!TypeExportable(T, Context, VD, NotLegacyKernelArgument)) {
709        return false;
710      }
711    }
712  }
713  return ValidateTypeHelper(Context, C, T, ND, Loc, SPS, false, nullptr, TargetAPI,
714                            IsFilterscript, IsExtern);
715}
716
717bool RSExportType::ValidateVarDecl(slang::RSContext *Context,
718                                   clang::VarDecl *VD, unsigned int TargetAPI,
719                                   bool IsFilterscript) {
720  return ValidateType(Context, VD->getASTContext(), VD->getType(), VD,
721                      VD->getLocation(), TargetAPI, IsFilterscript,
722                      (VD->getFormalLinkage() == clang::ExternalLinkage));
723}
724
725const clang::Type
726*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
727  if (DD) {
728    clang::QualType T = DD->getType();
729
730    if (T.isNull())
731      return nullptr;
732    else
733      return T.getTypePtr();
734  }
735  return nullptr;
736}
737
738llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
739  T = GetCanonicalType(T);
740  if (T == nullptr)
741    return llvm::StringRef();
742
743  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
744
745  switch (T->getTypeClass()) {
746    case clang::Type::Builtin: {
747      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
748      BuiltinInfo *info = FindBuiltinType(BT->getKind());
749      if (info != nullptr) {
750        return info->cname[0];
751      }
752      slangAssert(false && "Unknown data type of the builtin");
753      break;
754    }
755    case clang::Type::Record: {
756      clang::RecordDecl *RD;
757      if (T->isStructureType()) {
758        RD = T->getAsStructureType()->getDecl();
759      } else {
760        break;
761      }
762
763      llvm::StringRef Name = RD->getName();
764      if (Name.empty()) {
765        if (RD->getTypedefNameForAnonDecl() != nullptr) {
766          Name = RD->getTypedefNameForAnonDecl()->getName();
767        }
768
769        if (Name.empty()) {
770          // Try to find a name from redeclaration (i.e. typedef)
771          for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
772                   RE = RD->redecls_end();
773               RI != RE;
774               RI++) {
775            slangAssert(*RI != nullptr && "cannot be NULL object");
776
777            Name = (*RI)->getName();
778            if (!Name.empty())
779              break;
780          }
781        }
782      }
783      return Name;
784    }
785    case clang::Type::Pointer: {
786      // "*" plus pointee name
787      const clang::PointerType *P = static_cast<const clang::PointerType*>(CTI);
788      const clang::Type *PT = GetPointeeType(P);
789      llvm::StringRef PointeeName;
790      if (NormalizeType(PT, PointeeName, nullptr, nullptr,
791                        NotLegacyKernelArgument)) {
792        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
793        Name[0] = '*';
794        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
795        Name[PointeeName.size() + 1] = '\0';
796        return Name;
797      }
798      break;
799    }
800    case clang::Type::ExtVector: {
801      const clang::ExtVectorType *EVT =
802              static_cast<const clang::ExtVectorType*>(CTI);
803      return RSExportVectorType::GetTypeName(EVT);
804      break;
805    }
806    case clang::Type::ConstantArray : {
807      // Construct name for a constant array is too complicated.
808      return "<ConstantArray>";
809    }
810    default: {
811      break;
812    }
813  }
814
815  return llvm::StringRef();
816}
817
818
819RSExportType *RSExportType::Create(RSContext *Context,
820                                   const clang::Type *T,
821                                   const llvm::StringRef &TypeName,
822                                   ExportKind EK) {
823  // Lookup the context to see whether the type was processed before.
824  // Newly created RSExportType will insert into context
825  // in RSExportType::RSExportType()
826  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
827
828  if (ETI != Context->export_types_end())
829    return ETI->second;
830
831  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
832
833  RSExportType *ET = nullptr;
834  switch (T->getTypeClass()) {
835    case clang::Type::Record: {
836      DataType dt = RSExportPrimitiveType::GetRSSpecificType(TypeName);
837      switch (dt) {
838        case DataTypeUnknown: {
839          // User-defined types
840          ET = RSExportRecordType::Create(Context,
841                                          T->getAsStructureType(),
842                                          TypeName);
843          break;
844        }
845        case DataTypeRSMatrix2x2: {
846          // 2 x 2 Matrix type
847          ET = RSExportMatrixType::Create(Context,
848                                          T->getAsStructureType(),
849                                          TypeName,
850                                          2);
851          break;
852        }
853        case DataTypeRSMatrix3x3: {
854          // 3 x 3 Matrix type
855          ET = RSExportMatrixType::Create(Context,
856                                          T->getAsStructureType(),
857                                          TypeName,
858                                          3);
859          break;
860        }
861        case DataTypeRSMatrix4x4: {
862          // 4 x 4 Matrix type
863          ET = RSExportMatrixType::Create(Context,
864                                          T->getAsStructureType(),
865                                          TypeName,
866                                          4);
867          break;
868        }
869        default: {
870          // Others are primitive types
871          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
872          break;
873        }
874      }
875      break;
876    }
877    case clang::Type::Builtin: {
878      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
879      break;
880    }
881    case clang::Type::Pointer: {
882      ET = RSExportPointerType::Create(Context,
883                                       static_cast<const clang::PointerType*>(CTI),
884                                       TypeName);
885      // FIXME: free the name (allocated in RSExportType::GetTypeName)
886      delete [] TypeName.data();
887      break;
888    }
889    case clang::Type::ExtVector: {
890      ET = RSExportVectorType::Create(Context,
891                                      static_cast<const clang::ExtVectorType*>(CTI),
892                                      TypeName);
893      break;
894    }
895    case clang::Type::ConstantArray: {
896      ET = RSExportConstantArrayType::Create(
897              Context,
898              static_cast<const clang::ConstantArrayType*>(CTI));
899      break;
900    }
901    default: {
902      Context->ReportError("unknown type cannot be exported: '%0'")
903          << T->getTypeClassName();
904      break;
905    }
906  }
907
908  return ET;
909}
910
911RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T,
912                                   ExportKind EK) {
913  llvm::StringRef TypeName;
914  if (NormalizeType(T, TypeName, Context, nullptr, EK)) {
915    return Create(Context, T, TypeName, EK);
916  } else {
917    return nullptr;
918  }
919}
920
921RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
922                                           const clang::VarDecl *VD) {
923  return RSExportType::Create(Context, GetTypeOfDecl(VD),
924                              NotLegacyKernelArgument);
925}
926
927size_t RSExportType::getStoreSize() const {
928  return getRSContext()->getDataLayout()->getTypeStoreSize(getLLVMType());
929}
930
931size_t RSExportType::getAllocSize() const {
932    return getRSContext()->getDataLayout()->getTypeAllocSize(getLLVMType());
933}
934
935RSExportType::RSExportType(RSContext *Context,
936                           ExportClass Class,
937                           const llvm::StringRef &Name)
938    : RSExportable(Context, RSExportable::EX_TYPE),
939      mClass(Class),
940      // Make a copy on Name since memory stored @Name is either allocated in
941      // ASTContext or allocated in GetTypeName which will be destroyed later.
942      mName(Name.data(), Name.size()),
943      mLLVMType(nullptr) {
944  // Don't cache the type whose name start with '<'. Those type failed to
945  // get their name since constructing their name in GetTypeName() requiring
946  // complicated work.
947  if (!IsDummyName(Name)) {
948    // TODO(zonr): Need to check whether the insertion is successful or not.
949    Context->insertExportType(llvm::StringRef(Name), this);
950  }
951
952}
953
954bool RSExportType::keep() {
955  if (!RSExportable::keep())
956    return false;
957  // Invalidate converted LLVM type.
958  mLLVMType = nullptr;
959  return true;
960}
961
962bool RSExportType::equals(const RSExportable *E) const {
963  CHECK_PARENT_EQUALITY(RSExportable, E);
964  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
965}
966
967RSExportType::~RSExportType() {
968}
969
970/************************** RSExportPrimitiveType **************************/
971llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
972RSExportPrimitiveType::RSSpecificTypeMap;
973
974bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
975  if ((T != nullptr) && (T->getTypeClass() == clang::Type::Builtin))
976    return true;
977  else
978    return false;
979}
980
981DataType
982RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
983  if (TypeName.empty())
984    return DataTypeUnknown;
985
986  if (RSSpecificTypeMap->empty()) {
987    for (int i = 0; i < MatrixAndObjectDataTypesCount; i++) {
988      (*RSSpecificTypeMap)[MatrixAndObjectDataTypes[i].name] =
989          MatrixAndObjectDataTypes[i].dataType;
990    }
991  }
992
993  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
994  if (I == RSSpecificTypeMap->end())
995    return DataTypeUnknown;
996  else
997    return I->getValue();
998}
999
1000DataType RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
1001  T = GetCanonicalType(T);
1002  if ((T == nullptr) || (T->getTypeClass() != clang::Type::Record))
1003    return DataTypeUnknown;
1004
1005  return GetRSSpecificType( RSExportType::GetTypeName(T) );
1006}
1007
1008bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
1009    if (DT < 0 || DT >= DataTypeMax) {
1010        return false;
1011    }
1012    return gReflectionTypes[DT].category == MatrixDataType;
1013}
1014
1015bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
1016    if (DT < 0 || DT >= DataTypeMax) {
1017        return false;
1018    }
1019    return gReflectionTypes[DT].category == ObjectDataType;
1020}
1021
1022bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
1023  bool RSObjectTypeSeen = false;
1024  while (T && T->isArrayType()) {
1025    T = T->getArrayElementTypeNoTypeQual();
1026  }
1027
1028  const clang::RecordType *RT = T->getAsStructureType();
1029  if (!RT) {
1030    return false;
1031  }
1032
1033  const clang::RecordDecl *RD = RT->getDecl();
1034  if (RD) {
1035    RD = RD->getDefinition();
1036  }
1037  if (!RD) {
1038    return false;
1039  }
1040
1041  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1042         FE = RD->field_end();
1043       FI != FE;
1044       FI++) {
1045    // We just look through all field declarations to see if we find a
1046    // declaration for an RS object type (or an array of one).
1047    const clang::FieldDecl *FD = *FI;
1048    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1049    while (FT && FT->isArrayType()) {
1050      FT = FT->getArrayElementTypeNoTypeQual();
1051    }
1052
1053    DataType DT = GetRSSpecificType(FT);
1054    if (IsRSObjectType(DT)) {
1055      // RS object types definitely need to be zero-initialized
1056      RSObjectTypeSeen = true;
1057    } else {
1058      switch (DT) {
1059        case DataTypeRSMatrix2x2:
1060        case DataTypeRSMatrix3x3:
1061        case DataTypeRSMatrix4x4:
1062          // Matrix types should get zero-initialized as well
1063          RSObjectTypeSeen = true;
1064          break;
1065        default:
1066          // Ignore all other primitive types
1067          break;
1068      }
1069      while (FT && FT->isArrayType()) {
1070        FT = FT->getArrayElementTypeNoTypeQual();
1071      }
1072      if (FT->isStructureType()) {
1073        // Recursively handle structs of structs (even though these can't
1074        // be exported, it is possible for a user to have them internally).
1075        RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
1076      }
1077    }
1078  }
1079
1080  return RSObjectTypeSeen;
1081}
1082
1083size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
1084  int type = EPT->getType();
1085  slangAssert((type > DataTypeUnknown && type < DataTypeMax) &&
1086              "RSExportPrimitiveType::GetSizeInBits : unknown data type");
1087  // All RS object types are 256 bits in 64-bit RS.
1088  if (EPT->isRSObjectType() && EPT->getRSContext()->is64Bit()) {
1089    return 256;
1090  }
1091  return gReflectionTypes[type].size_in_bits;
1092}
1093
1094DataType
1095RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
1096  if (T == nullptr)
1097    return DataTypeUnknown;
1098
1099  switch (T->getTypeClass()) {
1100    case clang::Type::Builtin: {
1101      const clang::BuiltinType *BT =
1102              static_cast<const clang::BuiltinType*>(T->getCanonicalTypeInternal().getTypePtr());
1103      BuiltinInfo *info = FindBuiltinType(BT->getKind());
1104      if (info != nullptr) {
1105        return info->type;
1106      }
1107      // The size of type WChar depend on platform so we abandon the support
1108      // to them.
1109      Context->ReportError("built-in type cannot be exported: '%0'")
1110          << T->getTypeClassName();
1111      break;
1112    }
1113    case clang::Type::Record: {
1114      // must be RS object type
1115      return RSExportPrimitiveType::GetRSSpecificType(T);
1116    }
1117    default: {
1118      Context->ReportError("primitive type cannot be exported: '%0'")
1119          << T->getTypeClassName();
1120      break;
1121    }
1122  }
1123
1124  return DataTypeUnknown;
1125}
1126
1127RSExportPrimitiveType
1128*RSExportPrimitiveType::Create(RSContext *Context,
1129                               const clang::Type *T,
1130                               const llvm::StringRef &TypeName,
1131                               bool Normalized) {
1132  DataType DT = GetDataType(Context, T);
1133
1134  if ((DT == DataTypeUnknown) || TypeName.empty())
1135    return nullptr;
1136  else
1137    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
1138                                     DT, Normalized);
1139}
1140
1141RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
1142                                                     const clang::Type *T) {
1143  llvm::StringRef TypeName;
1144  if (RSExportType::NormalizeType(T, TypeName, Context, nullptr,
1145                                  NotLegacyKernelArgument) &&
1146      IsPrimitiveType(T)) {
1147    return Create(Context, T, TypeName);
1148  } else {
1149    return nullptr;
1150  }
1151}
1152
1153llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
1154  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1155
1156  if (isRSObjectType()) {
1157    // struct {
1158    //   int *p;
1159    // } __attribute__((packed, aligned(pointer_size)))
1160    //
1161    // which is
1162    //
1163    // <{ [1 x i32] }> in LLVM
1164    //
1165    std::vector<llvm::Type *> Elements;
1166    if (getRSContext()->is64Bit()) {
1167      // 64-bit path
1168      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt64Ty(C), 4));
1169      return llvm::StructType::get(C, Elements, true);
1170    } else {
1171      // 32-bit legacy path
1172      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
1173      return llvm::StructType::get(C, Elements, true);
1174    }
1175  }
1176
1177  switch (mType) {
1178    case DataTypeFloat16: {
1179      return llvm::Type::getHalfTy(C);
1180      break;
1181    }
1182    case DataTypeFloat32: {
1183      return llvm::Type::getFloatTy(C);
1184      break;
1185    }
1186    case DataTypeFloat64: {
1187      return llvm::Type::getDoubleTy(C);
1188      break;
1189    }
1190    case DataTypeBoolean: {
1191      return llvm::Type::getInt1Ty(C);
1192      break;
1193    }
1194    case DataTypeSigned8:
1195    case DataTypeUnsigned8: {
1196      return llvm::Type::getInt8Ty(C);
1197      break;
1198    }
1199    case DataTypeSigned16:
1200    case DataTypeUnsigned16:
1201    case DataTypeUnsigned565:
1202    case DataTypeUnsigned5551:
1203    case DataTypeUnsigned4444: {
1204      return llvm::Type::getInt16Ty(C);
1205      break;
1206    }
1207    case DataTypeSigned32:
1208    case DataTypeUnsigned32: {
1209      return llvm::Type::getInt32Ty(C);
1210      break;
1211    }
1212    case DataTypeSigned64:
1213    case DataTypeUnsigned64: {
1214      return llvm::Type::getInt64Ty(C);
1215      break;
1216    }
1217    default: {
1218      slangAssert(false && "Unknown data type");
1219    }
1220  }
1221
1222  return nullptr;
1223}
1224
1225bool RSExportPrimitiveType::equals(const RSExportable *E) const {
1226  CHECK_PARENT_EQUALITY(RSExportType, E);
1227  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
1228}
1229
1230RSReflectionType *RSExportPrimitiveType::getRSReflectionType(DataType DT) {
1231  if (DT > DataTypeUnknown && DT < DataTypeMax) {
1232    return &gReflectionTypes[DT];
1233  } else {
1234    return nullptr;
1235  }
1236}
1237
1238/**************************** RSExportPointerType ****************************/
1239
1240RSExportPointerType
1241*RSExportPointerType::Create(RSContext *Context,
1242                             const clang::PointerType *PT,
1243                             const llvm::StringRef &TypeName) {
1244  const clang::Type *PointeeType = GetPointeeType(PT);
1245  const RSExportType *PointeeET;
1246
1247  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
1248    PointeeET = RSExportType::Create(Context, PointeeType,
1249                                     NotLegacyKernelArgument);
1250  } else {
1251    // Double or higher dimension of pointer, export as int*
1252    PointeeET = RSExportPrimitiveType::Create(Context,
1253                    Context->getASTContext().IntTy.getTypePtr());
1254  }
1255
1256  if (PointeeET == nullptr) {
1257    // Error diagnostic is emitted for corresponding pointee type
1258    return nullptr;
1259  }
1260
1261  return new RSExportPointerType(Context, TypeName, PointeeET);
1262}
1263
1264llvm::Type *RSExportPointerType::convertToLLVMType() const {
1265  llvm::Type *PointeeType = mPointeeType->getLLVMType();
1266  return llvm::PointerType::getUnqual(PointeeType);
1267}
1268
1269bool RSExportPointerType::keep() {
1270  if (!RSExportType::keep())
1271    return false;
1272  const_cast<RSExportType*>(mPointeeType)->keep();
1273  return true;
1274}
1275
1276bool RSExportPointerType::equals(const RSExportable *E) const {
1277  CHECK_PARENT_EQUALITY(RSExportType, E);
1278  return (static_cast<const RSExportPointerType*>(E)
1279              ->getPointeeType()->equals(getPointeeType()));
1280}
1281
1282/***************************** RSExportVectorType *****************************/
1283llvm::StringRef
1284RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
1285  const clang::Type *ElementType = GetExtVectorElementType(EVT);
1286  llvm::StringRef name;
1287
1288  if ((ElementType->getTypeClass() != clang::Type::Builtin))
1289    return name;
1290
1291  const clang::BuiltinType *BT =
1292          static_cast<const clang::BuiltinType*>(
1293              ElementType->getCanonicalTypeInternal().getTypePtr());
1294
1295  if ((EVT->getNumElements() < 1) ||
1296      (EVT->getNumElements() > 4))
1297    return name;
1298
1299  BuiltinInfo *info = FindBuiltinType(BT->getKind());
1300  if (info != nullptr) {
1301    int I = EVT->getNumElements() - 1;
1302    if (I < kMaxVectorSize) {
1303      name = info->cname[I];
1304    } else {
1305      slangAssert(false && "Max vector is 4");
1306    }
1307  }
1308  return name;
1309}
1310
1311RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
1312                                               const clang::ExtVectorType *EVT,
1313                                               const llvm::StringRef &TypeName,
1314                                               bool Normalized) {
1315  slangAssert(EVT != nullptr && EVT->getTypeClass() == clang::Type::ExtVector);
1316
1317  const clang::Type *ElementType = GetExtVectorElementType(EVT);
1318  DataType DT = RSExportPrimitiveType::GetDataType(Context, ElementType);
1319
1320  if (DT != DataTypeUnknown)
1321    return new RSExportVectorType(Context,
1322                                  TypeName,
1323                                  DT,
1324                                  Normalized,
1325                                  EVT->getNumElements());
1326  else
1327    return nullptr;
1328}
1329
1330llvm::Type *RSExportVectorType::convertToLLVMType() const {
1331  llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
1332  return llvm::VectorType::get(ElementType, getNumElement());
1333}
1334
1335bool RSExportVectorType::equals(const RSExportable *E) const {
1336  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
1337  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
1338              == getNumElement());
1339}
1340
1341/***************************** RSExportMatrixType *****************************/
1342RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
1343                                               const clang::RecordType *RT,
1344                                               const llvm::StringRef &TypeName,
1345                                               unsigned Dim) {
1346  slangAssert((RT != nullptr) && (RT->getTypeClass() == clang::Type::Record));
1347  slangAssert((Dim > 1) && "Invalid dimension of matrix");
1348
1349  // Check whether the struct rs_matrix is in our expected form (but assume it's
1350  // correct if we're not sure whether it's correct or not)
1351  const clang::RecordDecl* RD = RT->getDecl();
1352  RD = RD->getDefinition();
1353  if (RD != nullptr) {
1354    // Find definition, perform further examination
1355    if (RD->field_empty()) {
1356      Context->ReportError(
1357          RD->getLocation(),
1358          "invalid matrix struct: must have 1 field for saving values: '%0'")
1359          << RD->getName();
1360      return nullptr;
1361    }
1362
1363    clang::RecordDecl::field_iterator FIT = RD->field_begin();
1364    const clang::FieldDecl *FD = *FIT;
1365    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1366    if ((FT == nullptr) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
1367      Context->ReportError(RD->getLocation(),
1368                           "invalid matrix struct: first field should"
1369                           " be an array with constant size: '%0'")
1370          << RD->getName();
1371      return nullptr;
1372    }
1373    const clang::ConstantArrayType *CAT =
1374      static_cast<const clang::ConstantArrayType *>(FT);
1375    const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1376    if ((ElementType == nullptr) ||
1377        (ElementType->getTypeClass() != clang::Type::Builtin) ||
1378        (static_cast<const clang::BuiltinType *>(ElementType)->getKind() !=
1379         clang::BuiltinType::Float)) {
1380      Context->ReportError(RD->getLocation(),
1381                           "invalid matrix struct: first field "
1382                           "should be a float array: '%0'")
1383          << RD->getName();
1384      return nullptr;
1385    }
1386
1387    if (CAT->getSize() != Dim * Dim) {
1388      Context->ReportError(RD->getLocation(),
1389                           "invalid matrix struct: first field "
1390                           "should be an array with size %0: '%1'")
1391          << (Dim * Dim) << (RD->getName());
1392      return nullptr;
1393    }
1394
1395    FIT++;
1396    if (FIT != RD->field_end()) {
1397      Context->ReportError(RD->getLocation(),
1398                           "invalid matrix struct: must have "
1399                           "exactly 1 field: '%0'")
1400          << RD->getName();
1401      return nullptr;
1402    }
1403  }
1404
1405  return new RSExportMatrixType(Context, TypeName, Dim);
1406}
1407
1408llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1409  // Construct LLVM type:
1410  // struct {
1411  //  float X[mDim * mDim];
1412  // }
1413
1414  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1415  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1416                                            mDim * mDim);
1417  return llvm::StructType::get(C, X, false);
1418}
1419
1420bool RSExportMatrixType::equals(const RSExportable *E) const {
1421  CHECK_PARENT_EQUALITY(RSExportType, E);
1422  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1423}
1424
1425/************************* RSExportConstantArrayType *************************/
1426RSExportConstantArrayType
1427*RSExportConstantArrayType::Create(RSContext *Context,
1428                                   const clang::ConstantArrayType *CAT) {
1429  slangAssert(CAT != nullptr && CAT->getTypeClass() == clang::Type::ConstantArray);
1430
1431  slangAssert((CAT->getSize().getActiveBits() < 32) && "array too large");
1432
1433  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1434  slangAssert((Size > 0) && "Constant array should have size greater than 0");
1435
1436  const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1437  RSExportType *ElementET = RSExportType::Create(Context, ElementType,
1438                                                 NotLegacyKernelArgument);
1439
1440  if (ElementET == nullptr) {
1441    return nullptr;
1442  }
1443
1444  return new RSExportConstantArrayType(Context,
1445                                       ElementET,
1446                                       Size);
1447}
1448
1449llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1450  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
1451}
1452
1453bool RSExportConstantArrayType::keep() {
1454  if (!RSExportType::keep())
1455    return false;
1456  const_cast<RSExportType*>(mElementType)->keep();
1457  return true;
1458}
1459
1460bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1461  CHECK_PARENT_EQUALITY(RSExportType, E);
1462  const RSExportConstantArrayType *RHS =
1463      static_cast<const RSExportConstantArrayType*>(E);
1464  return ((getSize() == RHS->getSize()) &&
1465          (getElementType()->equals(RHS->getElementType())));
1466}
1467
1468/**************************** RSExportRecordType ****************************/
1469RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1470                                               const clang::RecordType *RT,
1471                                               const llvm::StringRef &TypeName,
1472                                               bool mIsArtificial) {
1473  slangAssert(RT != nullptr && RT->getTypeClass() == clang::Type::Record);
1474
1475  const clang::RecordDecl *RD = RT->getDecl();
1476  slangAssert(RD->isStruct());
1477
1478  RD = RD->getDefinition();
1479  if (RD == nullptr) {
1480    slangAssert(false && "struct is not defined in this module");
1481    return nullptr;
1482  }
1483
1484  // Struct layout construct by clang. We rely on this for obtaining the
1485  // alloc size of a struct and offset of every field in that struct.
1486  const clang::ASTRecordLayout *RL =
1487      &Context->getASTContext().getASTRecordLayout(RD);
1488  slangAssert((RL != nullptr) &&
1489      "Failed to retrieve the struct layout from Clang.");
1490
1491  RSExportRecordType *ERT =
1492      new RSExportRecordType(Context,
1493                             TypeName,
1494                             RD->hasAttr<clang::PackedAttr>(),
1495                             mIsArtificial,
1496                             RL->getDataSize().getQuantity(),
1497                             RL->getSize().getQuantity());
1498  unsigned int Index = 0;
1499
1500  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1501           FE = RD->field_end();
1502       FI != FE;
1503       FI++, Index++) {
1504
1505    // FIXME: All fields should be primitive type
1506    slangAssert(FI->getKind() == clang::Decl::Field);
1507    clang::FieldDecl *FD = *FI;
1508
1509    if (FD->isBitField()) {
1510      return nullptr;
1511    }
1512
1513    // Type
1514    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1515
1516    if (ET != nullptr) {
1517      ERT->mFields.push_back(
1518          new Field(ET, FD->getName(), ERT,
1519                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1520    } else {
1521      Context->ReportError(RD->getLocation(),
1522                           "field type cannot be exported: '%0.%1'")
1523          << RD->getName() << FD->getName();
1524      return nullptr;
1525    }
1526  }
1527
1528  return ERT;
1529}
1530
1531llvm::Type *RSExportRecordType::convertToLLVMType() const {
1532  // Create an opaque type since struct may reference itself recursively.
1533
1534  // TODO(sliao): LLVM took out the OpaqueType. Any other to migrate to?
1535  std::vector<llvm::Type*> FieldTypes;
1536
1537  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1538       FI != FE;
1539       FI++) {
1540    const Field *F = *FI;
1541    const RSExportType *FET = F->getType();
1542
1543    FieldTypes.push_back(FET->getLLVMType());
1544  }
1545
1546  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1547                                               FieldTypes,
1548                                               mIsPacked);
1549  if (ST != nullptr) {
1550    return ST;
1551  } else {
1552    return nullptr;
1553  }
1554}
1555
1556bool RSExportRecordType::keep() {
1557  if (!RSExportType::keep())
1558    return false;
1559  for (std::list<const Field*>::iterator I = mFields.begin(),
1560          E = mFields.end();
1561       I != E;
1562       I++) {
1563    const_cast<RSExportType*>((*I)->getType())->keep();
1564  }
1565  return true;
1566}
1567
1568bool RSExportRecordType::equals(const RSExportable *E) const {
1569  CHECK_PARENT_EQUALITY(RSExportType, E);
1570
1571  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1572
1573  if (ERT->getFields().size() != getFields().size())
1574    return false;
1575
1576  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1577
1578  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1579    if (!(*AI)->getType()->equals((*BI)->getType()))
1580      return false;
1581    AI++;
1582    BI++;
1583  }
1584
1585  return true;
1586}
1587
1588void RSExportType::convertToRTD(RSReflectionTypeData *rtd) const {
1589    memset(rtd, 0, sizeof(*rtd));
1590    rtd->vecSize = 1;
1591
1592    switch(getClass()) {
1593    case RSExportType::ExportClassPrimitive: {
1594            const RSExportPrimitiveType *EPT = static_cast<const RSExportPrimitiveType*>(this);
1595            rtd->type = RSExportPrimitiveType::getRSReflectionType(EPT);
1596            return;
1597        }
1598    case RSExportType::ExportClassPointer: {
1599            const RSExportPointerType *EPT = static_cast<const RSExportPointerType*>(this);
1600            const RSExportType *PointeeType = EPT->getPointeeType();
1601            PointeeType->convertToRTD(rtd);
1602            rtd->isPointer = true;
1603            return;
1604        }
1605    case RSExportType::ExportClassVector: {
1606            const RSExportVectorType *EVT = static_cast<const RSExportVectorType*>(this);
1607            rtd->type = EVT->getRSReflectionType(EVT);
1608            rtd->vecSize = EVT->getNumElement();
1609            return;
1610        }
1611    case RSExportType::ExportClassMatrix: {
1612            const RSExportMatrixType *EMT = static_cast<const RSExportMatrixType*>(this);
1613            unsigned Dim = EMT->getDim();
1614            slangAssert((Dim >= 2) && (Dim <= 4));
1615            rtd->type = &gReflectionTypes[15 + Dim-2];
1616            return;
1617        }
1618    case RSExportType::ExportClassConstantArray: {
1619            const RSExportConstantArrayType* CAT =
1620              static_cast<const RSExportConstantArrayType*>(this);
1621            CAT->getElementType()->convertToRTD(rtd);
1622            rtd->arraySize = CAT->getSize();
1623            return;
1624        }
1625    case RSExportType::ExportClassRecord: {
1626            slangAssert(!"RSExportType::ExportClassRecord not implemented");
1627            return;// RS_TYPE_CLASS_NAME_PREFIX + ET->getName() + ".Item";
1628        }
1629    default: {
1630            slangAssert(false && "Unknown class of type");
1631        }
1632    }
1633}
1634
1635
1636}  // namespace slang
1637