slang_rs_export_type.cpp revision 2907b2a2768bc32f75867513528c8d7419e44780
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <list>
20#include <vector>
21
22#include "clang/AST/ASTContext.h"
23#include "clang/AST/Attr.h"
24#include "clang/AST/RecordLayout.h"
25
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/DerivedTypes.h"
29#include "llvm/IR/Type.h"
30
31#include "slang_assert.h"
32#include "slang_rs_context.h"
33#include "slang_rs_export_element.h"
34#include "slang_version.h"
35
// Makes the enclosing function return false when ParentClass::equals(E) is
// false. The expansion is a bare `if` statement that already ends in a
// semicolon, so it must only be used where a statement is allowed and inside
// a function whose return type accepts `false`.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  if (!ParentClass::equals(E))                \
    return false;
39
40namespace slang {
41
42namespace {
43
// Reflection metadata for every data type we support. Each row supplies, in
// order:
//  Category      - data type category
//  RsType        - element name in RenderScript
//  RsShortType   - short element name in RenderScript
//  SizeInBits    - size in bits
//  CName         - reflected C name
//  JavaName      - reflected Java name
//  JavaArrayElementName - reflected name in Java arrays
//  CVecName      - prefix for C vector types
//  JavaVecName   - prefix for Java vector type
//  JavaPromotion - unsigned type undergoing Java promotion
//
// Within the table `_` is a temporary shorthand for nullptr and marks a column
// that has no value for that row; the macro is undefined again right after
// the last row.
//
// IMPORTANT: The data types in this table should be at the same index as
// specified by the corresponding DataType enum.
//
// TODO: Pull this information out into a separate file.
static RSReflectionType gReflectionTypes[] = {
#define _ nullptr
  //      Category             RsType       RsST  Bits     CName         JN     JAEN        CVN       JVN      JP
{PrimitiveDataType,         "FLOAT_16",     "F16", 16,     "half",   "short",        _,   "Half",   "Half", false},
{PrimitiveDataType,         "FLOAT_32",     "F32", 32,    "float",   "float",  "float",  "Float",  "Float", false},
{PrimitiveDataType,         "FLOAT_64",     "F64", 64,   "double",  "double", "double", "Double", "Double", false},
{PrimitiveDataType,         "SIGNED_8",      "I8",  8,   "int8_t",    "byte",   "byte",   "Byte",   "Byte", false},
{PrimitiveDataType,        "SIGNED_16",     "I16", 16,  "int16_t",   "short",  "short",  "Short",  "Short", false},
{PrimitiveDataType,        "SIGNED_32",     "I32", 32,  "int32_t",     "int",    "int",    "Int",    "Int", false},
{PrimitiveDataType,        "SIGNED_64",     "I64", 64,  "int64_t",    "long",   "long",   "Long",   "Long", false},
{PrimitiveDataType,       "UNSIGNED_8",      "U8",  8,  "uint8_t",   "short",   "byte",  "UByte",  "Short",  true},
{PrimitiveDataType,      "UNSIGNED_16",     "U16", 16, "uint16_t",     "int",  "short", "UShort",    "Int",  true},
{PrimitiveDataType,      "UNSIGNED_32",     "U32", 32, "uint32_t",    "long",    "int",   "UInt",   "Long",  true},
{PrimitiveDataType,      "UNSIGNED_64",     "U64", 64, "uint64_t",    "long",   "long",  "ULong",   "Long", false},
{PrimitiveDataType,          "BOOLEAN", "BOOLEAN",  8,     "bool", "boolean",   "byte",        _,        _, false},
{PrimitiveDataType,   "UNSIGNED_5_6_5",         _, 16,          _,         _,        _,        _,        _, false},
{PrimitiveDataType, "UNSIGNED_5_5_5_1",         _, 16,          _,         _,        _,        _,        _, false},
{PrimitiveDataType, "UNSIGNED_4_4_4_4",         _, 16,          _,         _,        _,        _,        _, false},

// Matrix rows: the size column is written as <element count> * 32 bits.
{MatrixDataType, "MATRIX_2X2", _,  4*32, "rs_matrix2x2", "Matrix2f", _, _, _, false},
{MatrixDataType, "MATRIX_3X3", _,  9*32, "rs_matrix3x3", "Matrix3f", _, _, _, false},
{MatrixDataType, "MATRIX_4X4", _, 16*32, "rs_matrix4x4", "Matrix4f", _, _, _, false},

// RS object types are 32 bits in 32-bit RS, but 256 bits in 64-bit RS.
// This is handled specially by the GetElementSizeInBits() method.
{ObjectDataType,          "RS_ELEMENT",          "ELEMENT", 32,         "Element",         "Element", _, _, _, false},
{ObjectDataType,             "RS_TYPE",             "TYPE", 32,            "Type",            "Type", _, _, _, false},
{ObjectDataType,       "RS_ALLOCATION",       "ALLOCATION", 32,      "Allocation",      "Allocation", _, _, _, false},
{ObjectDataType,          "RS_SAMPLER",          "SAMPLER", 32,         "Sampler",         "Sampler", _, _, _, false},
{ObjectDataType,           "RS_SCRIPT",           "SCRIPT", 32,          "Script",          "Script", _, _, _, false},
{ObjectDataType,             "RS_MESH",             "MESH", 32,            "Mesh",            "Mesh", _, _, _, false},
{ObjectDataType,             "RS_PATH",             "PATH", 32,            "Path",            "Path", _, _, _, false},
{ObjectDataType, "RS_PROGRAM_FRAGMENT", "PROGRAM_FRAGMENT", 32, "ProgramFragment", "ProgramFragment", _, _, _, false},
{ObjectDataType,   "RS_PROGRAM_VERTEX",   "PROGRAM_VERTEX", 32,   "ProgramVertex",   "ProgramVertex", _, _, _, false},
{ObjectDataType,   "RS_PROGRAM_RASTER",   "PROGRAM_RASTER", 32,   "ProgramRaster",   "ProgramRaster", _, _, _, false},
{ObjectDataType,    "RS_PROGRAM_STORE",    "PROGRAM_STORE", 32,    "ProgramStore",    "ProgramStore", _, _, _, false},
{ObjectDataType,             "RS_FONT",             "FONT", 32,            "Font",            "Font", _, _, _, false},
#undef _
};
99
// Number of name slots stored per builtin type in BuiltinInfo::cname.
const int kMaxVectorSize = 4;

// One row of BuiltinInfoTable: associates a clang builtin type kind with the
// RS DataType it is exported as and with its reflected C names.
struct BuiltinInfo {
  // The clang builtin kind this entry describes (lookup key).
  clang::BuiltinType::Kind builtinTypeKind;
  // The RS data type the builtin is mapped to.
  DataType type;
  /* TODO If we return std::string instead of llvm::StringRef, we could build
   * the name instead of duplicating the entries.
   */
  // C names for this type; slot 0 is the scalar name, the remaining slots
  // carry the same name with a 2/3/4 suffix in the table below.
  const char *cname[kMaxVectorSize];
};
110
111
// Lookup table from clang builtin kinds to RS data types and C names. It is
// searched linearly by FindBuiltinType(). Several clang kinds share one RS
// type: e.g. both ULong and ULongLong map to DataTypeUnsigned64, and both
// Char_U and UChar map to DataTypeUnsigned8.
BuiltinInfo BuiltinInfoTable[] = {
    // Boolean and unsigned (or unsigned-like) kinds.
    {clang::BuiltinType::Bool, DataTypeBoolean,
     {"bool", "bool2", "bool3", "bool4"}},
    {clang::BuiltinType::Char_U, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::UChar, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    // Note: Char16 and Char32 are exported as the signed 16/32-bit types.
    {clang::BuiltinType::Char16, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Char32, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::UShort, DataTypeUnsigned16,
     {"ushort", "ushort2", "ushort3", "ushort4"}},
    {clang::BuiltinType::UInt, DataTypeUnsigned32,
     {"uint", "uint2", "uint3", "uint4"}},
    {clang::BuiltinType::ULong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},
    {clang::BuiltinType::ULongLong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},

    // Signed integer and floating point kinds.
    {clang::BuiltinType::Char_S, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::SChar, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::Short, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Int, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::Long, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::LongLong, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::Half, DataTypeFloat16,
     {"half", "half2", "half3", "half4"}},
    {clang::BuiltinType::Float, DataTypeFloat32,
     {"float", "float2", "float3", "float4"}},
    {clang::BuiltinType::Double, DataTypeFloat64,
     {"double", "double2", "double3", "double4"}},
};
// Number of rows in BuiltinInfoTable, derived from the array size.
const int BuiltinInfoTableCount = sizeof(BuiltinInfoTable) / sizeof(BuiltinInfoTable[0]);
152
// Pairs a type name as written in RS source with the DataType it denotes.
struct NameAndPrimitiveType {
  const char *name;   // RS source-level type name, e.g. "rs_allocation"
  DataType dataType;  // corresponding DataType enumerator
};

// Name -> DataType table for the RS matrix types (first three rows) and the
// RS object types (remaining rows).
static NameAndPrimitiveType MatrixAndObjectDataTypes[] = {
    {"rs_matrix2x2", DataTypeRSMatrix2x2},
    {"rs_matrix3x3", DataTypeRSMatrix3x3},
    {"rs_matrix4x4", DataTypeRSMatrix4x4},
    {"rs_element", DataTypeRSElement},
    {"rs_type", DataTypeRSType},
    {"rs_allocation", DataTypeRSAllocation},
    {"rs_sampler", DataTypeRSSampler},
    {"rs_script", DataTypeRSScript},
    {"rs_mesh", DataTypeRSMesh},
    {"rs_path", DataTypeRSPath},
    {"rs_program_fragment", DataTypeRSProgramFragment},
    {"rs_program_vertex", DataTypeRSProgramVertex},
    {"rs_program_raster", DataTypeRSProgramRaster},
    {"rs_program_store", DataTypeRSProgramStore},
    {"rs_font", DataTypeRSFont},
};

// Number of rows in MatrixAndObjectDataTypes, derived from the array size.
const int MatrixAndObjectDataTypesCount =
    sizeof(MatrixAndObjectDataTypes) / sizeof(MatrixAndObjectDataTypes[0]);
178
// Forward declaration: ConstantArrayTypeExportableHelper() below calls this
// function for the array element type, and this function in turn calls
// ConstantArrayTypeExportableHelper() for constant arrays, so one of the two
// has to be declared before it is defined.
static const clang::Type *TypeExportableHelper(
    const clang::Type *T,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    slang::RSContext *Context,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord,
    ExportKind EK);
186
187template <unsigned N>
188static void ReportTypeError(slang::RSContext *Context,
189                            const clang::NamedDecl *ND,
190                            const clang::RecordDecl *TopLevelRecord,
191                            const char (&Message)[N],
192                            unsigned int TargetAPI = 0) {
193  // Attempt to use the type declaration first (if we have one).
194  // Fall back to the variable definition, if we are looking at something
195  // like an array declaration that can't be exported.
196  if (TopLevelRecord) {
197    Context->ReportError(TopLevelRecord->getLocation(), Message)
198        << TopLevelRecord->getName() << TargetAPI;
199  } else if (ND) {
200    Context->ReportError(ND->getLocation(), Message) << ND->getName()
201                                                     << TargetAPI;
202  } else {
203    slangAssert(false && "Variables should be validated before exporting");
204  }
205}
206
207static const clang::Type *ConstantArrayTypeExportableHelper(
208    const clang::ConstantArrayType *CAT,
209    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
210    slang::RSContext *Context,
211    const clang::VarDecl *VD,
212    const clang::RecordDecl *TopLevelRecord,
213    ExportKind EK) {
214  // Check element type
215  const clang::Type *ElementType = GetConstantArrayElementType(CAT);
216  if (ElementType->isArrayType()) {
217    ReportTypeError(Context, VD, TopLevelRecord,
218                    "multidimensional arrays cannot be exported: '%0'");
219    return nullptr;
220  } else if (ElementType->isExtVectorType()) {
221    const clang::ExtVectorType *EVT =
222        static_cast<const clang::ExtVectorType*>(ElementType);
223    unsigned numElements = EVT->getNumElements();
224
225    const clang::Type *BaseElementType = GetExtVectorElementType(EVT);
226    if (!RSExportPrimitiveType::IsPrimitiveType(BaseElementType)) {
227      ReportTypeError(Context, VD, TopLevelRecord,
228        "vectors of non-primitive types cannot be exported: '%0'");
229      return nullptr;
230    }
231
232    if (numElements == 3 && CAT->getSize() != 1) {
233      ReportTypeError(Context, VD, TopLevelRecord,
234        "arrays of width 3 vector types cannot be exported: '%0'");
235      return nullptr;
236    }
237  }
238
239  if (TypeExportableHelper(ElementType, SPS, Context, VD,
240                           TopLevelRecord, EK) == nullptr) {
241    return nullptr;
242  } else {
243    return CAT;
244  }
245}
246
247BuiltinInfo *FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind) {
248  for (int i = 0; i < BuiltinInfoTableCount; i++) {
249    if (builtinTypeKind == BuiltinInfoTable[i].builtinTypeKind) {
250      return &BuiltinInfoTable[i];
251    }
252  }
253  return nullptr;
254}
255
256static const clang::Type *TypeExportableHelper(
257    clang::Type const *T,
258    llvm::SmallPtrSet<clang::Type const *, 8> &SPS,
259    slang::RSContext *Context,
260    clang::VarDecl const *VD,
261    clang::RecordDecl const *TopLevelRecord,
262    ExportKind EK) {
263  // Normalize first
264  if ((T = GetCanonicalType(T)) == nullptr)
265    return nullptr;
266
267  if (SPS.count(T))
268    return T;
269
270  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
271
272  switch (T->getTypeClass()) {
273    case clang::Type::Builtin: {
274      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
275      return FindBuiltinType(BT->getKind()) == nullptr ? nullptr : T;
276    }
277    case clang::Type::Record: {
278      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
279        return T;  // RS object type, no further checks are needed
280      }
281
282      // Check internal struct
283      if (T->isUnionType()) {
284        ReportTypeError(Context, VD, T->getAsUnionType()->getDecl(),
285                        "unions cannot be exported: '%0'");
286        return nullptr;
287      } else if (!T->isStructureType()) {
288        slangAssert(false && "Unknown type cannot be exported");
289        return nullptr;
290      }
291
292      clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
293      if (RD != nullptr) {
294        RD = RD->getDefinition();
295        if (RD == nullptr) {
296          ReportTypeError(Context, nullptr, T->getAsStructureType()->getDecl(),
297                          "struct is not defined in this module");
298          return nullptr;
299        }
300      }
301
302      if (!TopLevelRecord) {
303        TopLevelRecord = RD;
304      }
305      if (RD->getName().empty()) {
306        ReportTypeError(Context, nullptr, RD,
307                        "anonymous structures cannot be exported");
308        return nullptr;
309      }
310
311      // Fast check
312      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
313        return nullptr;
314
315      // Insert myself into checking set
316      SPS.insert(T);
317
318      // Check all element
319      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
320               FE = RD->field_end();
321           FI != FE;
322           FI++) {
323        const clang::FieldDecl *FD = *FI;
324        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
325        FT = GetCanonicalType(FT);
326
327        if (!TypeExportableHelper(FT, SPS, Context, VD, TopLevelRecord,
328                                  EK)) {
329          return nullptr;
330        }
331
332        // We don't support bit fields yet
333        //
334        // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
335        if (FD->isBitField()) {
336          Context->ReportError(
337              FD->getLocation(),
338              "bit fields are not able to be exported: '%0.%1'")
339              << RD->getName() << FD->getName();
340          return nullptr;
341        }
342      }
343
344      return T;
345    }
346    case clang::Type::FunctionProto:
347      ReportTypeError(Context, VD, TopLevelRecord,
348                      "function types cannot be exported: '%0'");
349      return nullptr;
350    case clang::Type::Pointer: {
351      if (TopLevelRecord) {
352        ReportTypeError(Context, VD, TopLevelRecord,
353            "structures containing pointers cannot be used as the type of "
354            "an exported global variable or the parameter to an exported "
355            "function: '%0'");
356        return nullptr;
357      }
358
359      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
360      const clang::Type *PointeeType = GetPointeeType(PT);
361
362      if (PointeeType->getTypeClass() == clang::Type::Pointer) {
363        ReportTypeError(Context, VD, TopLevelRecord,
364            "multiple levels of pointers cannot be exported: '%0'");
365        return nullptr;
366      }
367
368      // Void pointers are forbidden for export, although we must accept
369      // void pointers that come in as arguments to a legacy kernel.
370      if (PointeeType->isVoidType() && EK != LegacyKernelArgument) {
371        ReportTypeError(Context, VD, TopLevelRecord,
372            "void pointers cannot be exported: '%0'");
373        return nullptr;
374      }
375
376      // We don't support pointer with array-type pointee or unsupported pointee
377      // type
378      if (PointeeType->isArrayType() ||
379          (TypeExportableHelper(PointeeType, SPS, Context, VD,
380                                TopLevelRecord, EK) == nullptr))
381        return nullptr;
382      else
383        return T;
384    }
385    case clang::Type::ExtVector: {
386      const clang::ExtVectorType *EVT =
387              static_cast<const clang::ExtVectorType*>(CTI);
388      // Only vector with size 2, 3 and 4 are supported.
389      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
390        return nullptr;
391
392      // Check base element type
393      const clang::Type *ElementType = GetExtVectorElementType(EVT);
394
395      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
396          (TypeExportableHelper(ElementType, SPS, Context, VD,
397                                TopLevelRecord, EK) == nullptr))
398        return nullptr;
399      else
400        return T;
401    }
402    case clang::Type::ConstantArray: {
403      const clang::ConstantArrayType *CAT =
404              static_cast<const clang::ConstantArrayType*>(CTI);
405
406      return ConstantArrayTypeExportableHelper(CAT, SPS, Context, VD,
407                                               TopLevelRecord, EK);
408    }
409    case clang::Type::Enum: {
410      // FIXME: We currently convert enums to integers, rather than reflecting
411      // a more complete (and nicer type-safe Java version).
412      return Context->getASTContext().IntTy.getTypePtr();
413    }
414    default: {
415      slangAssert(false && "Unknown type cannot be validated");
416      return nullptr;
417    }
418  }
419}
420
421// Return the type that can be used to create RSExportType, will always return
422// the canonical type.
423//
424// If the Type T is not exportable, this function returns nullptr. DiagEngine is
425// used to generate proper Clang diagnostic messages when a non-exportable type
426// is detected. TopLevelRecord is used to capture the highest struct (in the
427// case of a nested hierarchy) for detecting other types that cannot be exported
428// (mostly pointers within a struct).
429static const clang::Type *TypeExportable(const clang::Type *T,
430                                         slang::RSContext *Context,
431                                         const clang::VarDecl *VD,
432                                         ExportKind EK) {
433  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
434      llvm::SmallPtrSet<const clang::Type*, 8>();
435
436  return TypeExportableHelper(T, SPS, Context, VD, nullptr, EK);
437}
438
439static bool ValidateRSObjectInVarDecl(slang::RSContext *Context,
440                                      const clang::VarDecl *VD, bool InCompositeType,
441                                      unsigned int TargetAPI) {
442  if (TargetAPI < SLANG_JB_TARGET_API) {
443    // Only if we are already in a composite type (like an array or structure).
444    if (InCompositeType) {
445      // Only if we are actually exported (i.e. non-static).
446      if (VD->hasLinkage() &&
447          (VD->getFormalLinkage() == clang::ExternalLinkage)) {
448        // Only if we are not a pointer to an object.
449        const clang::Type *T = GetCanonicalType(VD->getType().getTypePtr());
450        if (T->getTypeClass() != clang::Type::Pointer) {
451          ReportTypeError(Context, VD, nullptr,
452                          "arrays/structures containing RS object types "
453                          "cannot be exported in target API < %1: '%0'",
454                          SLANG_JB_TARGET_API);
455          return false;
456        }
457      }
458    }
459  }
460
461  return true;
462}
463
464// Helper function for ValidateType(). We do a recursive descent on the
465// type hierarchy to ensure that we can properly export/handle the
466// declaration.
467// \return true if the variable declaration is valid,
468//         false if it is invalid (along with proper diagnostics).
469//
470// C - ASTContext (for diagnostics + builtin types).
471// T - sub-type that we are validating.
472// ND - (optional) top-level named declaration that we are validating.
473// SPS - set of types we have already seen/validated.
474// InCompositeType - true if we are within an outer composite type.
475// UnionDecl - set if we are in a sub-type of a union.
476// TargetAPI - target SDK API level.
477// IsFilterscript - whether or not we are compiling for Filterscript
478// IsExtern - is this type externally visible (i.e. extern global or parameter
479//                                             to an extern function)
480static bool ValidateTypeHelper(
481    slang::RSContext *Context,
482    clang::ASTContext &C,
483    const clang::Type *&T,
484    const clang::NamedDecl *ND,
485    clang::SourceLocation Loc,
486    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
487    bool InCompositeType,
488    clang::RecordDecl *UnionDecl,
489    unsigned int TargetAPI,
490    bool IsFilterscript,
491    bool IsExtern) {
492  if ((T = GetCanonicalType(T)) == nullptr)
493    return true;
494
495  if (SPS.count(T))
496    return true;
497
498  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
499
500  switch (T->getTypeClass()) {
501    case clang::Type::Record: {
502      if (RSExportPrimitiveType::IsRSObjectType(T)) {
503        const clang::VarDecl *VD = (ND ? llvm::dyn_cast<clang::VarDecl>(ND) : nullptr);
504        if (VD && !ValidateRSObjectInVarDecl(Context, VD, InCompositeType,
505                                             TargetAPI)) {
506          return false;
507        }
508      }
509
510      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
511        if (!UnionDecl) {
512          return true;
513        } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
514          ReportTypeError(Context, nullptr, UnionDecl,
515              "unions containing RS object types are not allowed");
516          return false;
517        }
518      }
519
520      clang::RecordDecl *RD = nullptr;
521
522      // Check internal struct
523      if (T->isUnionType()) {
524        RD = T->getAsUnionType()->getDecl();
525        UnionDecl = RD;
526      } else if (T->isStructureType()) {
527        RD = T->getAsStructureType()->getDecl();
528      } else {
529        slangAssert(false && "Unknown type cannot be exported");
530        return false;
531      }
532
533      if (RD != nullptr) {
534        RD = RD->getDefinition();
535        if (RD == nullptr) {
536          // FIXME
537          return true;
538        }
539      }
540
541      // Fast check
542      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
543        return false;
544
545      // Insert myself into checking set
546      SPS.insert(T);
547
548      // Check all elements
549      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
550               FE = RD->field_end();
551           FI != FE;
552           FI++) {
553        const clang::FieldDecl *FD = *FI;
554        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
555        FT = GetCanonicalType(FT);
556
557        if (!ValidateTypeHelper(Context, C, FT, ND, Loc, SPS, true, UnionDecl,
558                                TargetAPI, IsFilterscript, IsExtern)) {
559          return false;
560        }
561      }
562
563      return true;
564    }
565
566    case clang::Type::Builtin: {
567      if (IsFilterscript) {
568        clang::QualType QT = T->getCanonicalTypeInternal();
569        if (QT == C.DoubleTy ||
570            QT == C.LongDoubleTy ||
571            QT == C.LongTy ||
572            QT == C.LongLongTy) {
573          if (ND) {
574            Context->ReportError(
575                Loc,
576                "Builtin types > 32 bits in size are forbidden in "
577                "Filterscript: '%0'")
578                << ND->getName();
579          } else {
580            Context->ReportError(
581                Loc,
582                "Builtin types > 32 bits in size are forbidden in "
583                "Filterscript");
584          }
585          return false;
586        }
587      }
588      break;
589    }
590
591    case clang::Type::Pointer: {
592      if (IsFilterscript) {
593        if (ND) {
594          Context->ReportError(Loc,
595                               "Pointers are forbidden in Filterscript: '%0'")
596              << ND->getName();
597          return false;
598        } else {
599          // TODO(srhines): Find a better way to handle expressions (i.e. no
600          // NamedDecl) involving pointers in FS that should be allowed.
601          // An example would be calls to library functions like
602          // rsMatrixMultiply() that take rs_matrixNxN * types.
603        }
604      }
605
606      // Forbid pointers in structures that are externally visible.
607      if (InCompositeType && IsExtern) {
608        if (ND) {
609          Context->ReportError(Loc,
610              "structures containing pointers cannot be used as the type of "
611              "an exported global variable or the parameter to an exported "
612              "function: '%0'")
613            << ND->getName();
614        } else {
615          Context->ReportError(Loc,
616              "structures containing pointers cannot be used as the type of "
617              "an exported global variable or the parameter to an exported "
618              "function");
619        }
620        return false;
621      }
622
623      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
624      const clang::Type *PointeeType = GetPointeeType(PT);
625
626      return ValidateTypeHelper(Context, C, PointeeType, ND, Loc, SPS,
627                                InCompositeType, UnionDecl, TargetAPI,
628                                IsFilterscript, IsExtern);
629    }
630
631    case clang::Type::ExtVector: {
632      const clang::ExtVectorType *EVT =
633              static_cast<const clang::ExtVectorType*>(CTI);
634      const clang::Type *ElementType = GetExtVectorElementType(EVT);
635      if (TargetAPI < SLANG_ICS_TARGET_API &&
636          InCompositeType &&
637          EVT->getNumElements() == 3 &&
638          ND &&
639          ND->getFormalLinkage() == clang::ExternalLinkage) {
640        ReportTypeError(Context, ND, nullptr,
641                        "structs containing vectors of dimension 3 cannot "
642                        "be exported at this API level: '%0'");
643        return false;
644      }
645      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
646                                UnionDecl, TargetAPI, IsFilterscript, IsExtern);
647    }
648
649    case clang::Type::ConstantArray: {
650      const clang::ConstantArrayType *CAT = static_cast<const clang::ConstantArrayType*>(CTI);
651      const clang::Type *ElementType = GetConstantArrayElementType(CAT);
652      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
653                                UnionDecl, TargetAPI, IsFilterscript, IsExtern);
654    }
655
656    default: {
657      break;
658    }
659  }
660
661  return true;
662}
663
664}  // namespace
665
// Builds a placeholder name of the form "<type>" or, when name is not empty,
// "<type:name>".
std::string CreateDummyName(const char *type, const std::string &name) {
  std::stringstream Out;
  Out << '<' << type;
  if (name.size() != 0)
    Out << ':' << name;
  Out << '>';
  return Out.str();
}
675
676/****************************** RSExportType ******************************/
677bool RSExportType::NormalizeType(const clang::Type *&T,
678                                 llvm::StringRef &TypeName,
679                                 RSContext *Context,
680                                 const clang::VarDecl *VD,
681                                 ExportKind EK) {
682  if ((T = TypeExportable(T, Context, VD, EK)) == nullptr) {
683    return false;
684  }
685  // Get type name
686  TypeName = RSExportType::GetTypeName(T);
687  if (Context && TypeName.empty()) {
688    if (VD) {
689      Context->ReportError(VD->getLocation(),
690                           "anonymous types cannot be exported");
691    } else {
692      Context->ReportError("anonymous types cannot be exported");
693    }
694    return false;
695  }
696
697  return true;
698}
699
700bool RSExportType::ValidateType(slang::RSContext *Context, clang::ASTContext &C,
701                                clang::QualType QT, const clang::NamedDecl *ND,
702                                clang::SourceLocation Loc,
703                                unsigned int TargetAPI, bool IsFilterscript,
704                                bool IsExtern) {
705  const clang::Type *T = QT.getTypePtr();
706  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
707      llvm::SmallPtrSet<const clang::Type*, 8>();
708
709  // If this is an externally visible variable declaration, we check if the
710  // type is able to be exported first.
711  if (auto VD = llvm::dyn_cast_or_null<clang::VarDecl>(ND)) {
712    if (VD->getFormalLinkage() == clang::ExternalLinkage) {
713      if (!TypeExportable(T, Context, VD, NotLegacyKernelArgument)) {
714        return false;
715      }
716    }
717  }
718  return ValidateTypeHelper(Context, C, T, ND, Loc, SPS, false, nullptr, TargetAPI,
719                            IsFilterscript, IsExtern);
720}
721
722bool RSExportType::ValidateVarDecl(slang::RSContext *Context,
723                                   clang::VarDecl *VD, unsigned int TargetAPI,
724                                   bool IsFilterscript) {
725  return ValidateType(Context, VD->getASTContext(), VD->getType(), VD,
726                      VD->getLocation(), TargetAPI, IsFilterscript,
727                      (VD->getFormalLinkage() == clang::ExternalLinkage));
728}
729
730const clang::Type
731*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
732  if (DD) {
733    clang::QualType T = DD->getType();
734
735    if (T.isNull())
736      return nullptr;
737    else
738      return T.getTypePtr();
739  }
740  return nullptr;
741}
742
743llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
744  T = GetCanonicalType(T);
745  if (T == nullptr)
746    return llvm::StringRef();
747
748  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
749
750  switch (T->getTypeClass()) {
751    case clang::Type::Builtin: {
752      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
753      BuiltinInfo *info = FindBuiltinType(BT->getKind());
754      if (info != nullptr) {
755        return info->cname[0];
756      }
757      slangAssert(false && "Unknown data type of the builtin");
758      break;
759    }
760    case clang::Type::Record: {
761      clang::RecordDecl *RD;
762      if (T->isStructureType()) {
763        RD = T->getAsStructureType()->getDecl();
764      } else {
765        break;
766      }
767
768      llvm::StringRef Name = RD->getName();
769      if (Name.empty()) {
770        if (RD->getTypedefNameForAnonDecl() != nullptr) {
771          Name = RD->getTypedefNameForAnonDecl()->getName();
772        }
773
774        if (Name.empty()) {
775          // Try to find a name from redeclaration (i.e. typedef)
776          for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
777                   RE = RD->redecls_end();
778               RI != RE;
779               RI++) {
780            slangAssert(*RI != nullptr && "cannot be NULL object");
781
782            Name = (*RI)->getName();
783            if (!Name.empty())
784              break;
785          }
786        }
787      }
788      return Name;
789    }
790    case clang::Type::Pointer: {
791      // "*" plus pointee name
792      const clang::PointerType *P = static_cast<const clang::PointerType*>(CTI);
793      const clang::Type *PT = GetPointeeType(P);
794      llvm::StringRef PointeeName;
795      if (NormalizeType(PT, PointeeName, nullptr, nullptr,
796                        NotLegacyKernelArgument)) {
797        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
798        Name[0] = '*';
799        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
800        Name[PointeeName.size() + 1] = '\0';
801        return Name;
802      }
803      break;
804    }
805    case clang::Type::ExtVector: {
806      const clang::ExtVectorType *EVT =
807              static_cast<const clang::ExtVectorType*>(CTI);
808      return RSExportVectorType::GetTypeName(EVT);
809      break;
810    }
811    case clang::Type::ConstantArray : {
812      // Construct name for a constant array is too complicated.
813      return "<ConstantArray>";
814    }
815    default: {
816      break;
817    }
818  }
819
820  return llvm::StringRef();
821}
822
823
824RSExportType *RSExportType::Create(RSContext *Context,
825                                   const clang::Type *T,
826                                   const llvm::StringRef &TypeName,
827                                   ExportKind EK) {
828  // Lookup the context to see whether the type was processed before.
829  // Newly created RSExportType will insert into context
830  // in RSExportType::RSExportType()
831  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
832
833  if (ETI != Context->export_types_end())
834    return ETI->second;
835
836  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
837
838  RSExportType *ET = nullptr;
839  switch (T->getTypeClass()) {
840    case clang::Type::Record: {
841      DataType dt = RSExportPrimitiveType::GetRSSpecificType(TypeName);
842      switch (dt) {
843        case DataTypeUnknown: {
844          // User-defined types
845          ET = RSExportRecordType::Create(Context,
846                                          T->getAsStructureType(),
847                                          TypeName);
848          break;
849        }
850        case DataTypeRSMatrix2x2: {
851          // 2 x 2 Matrix type
852          ET = RSExportMatrixType::Create(Context,
853                                          T->getAsStructureType(),
854                                          TypeName,
855                                          2);
856          break;
857        }
858        case DataTypeRSMatrix3x3: {
859          // 3 x 3 Matrix type
860          ET = RSExportMatrixType::Create(Context,
861                                          T->getAsStructureType(),
862                                          TypeName,
863                                          3);
864          break;
865        }
866        case DataTypeRSMatrix4x4: {
867          // 4 x 4 Matrix type
868          ET = RSExportMatrixType::Create(Context,
869                                          T->getAsStructureType(),
870                                          TypeName,
871                                          4);
872          break;
873        }
874        default: {
875          // Others are primitive types
876          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
877          break;
878        }
879      }
880      break;
881    }
882    case clang::Type::Builtin: {
883      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
884      break;
885    }
886    case clang::Type::Pointer: {
887      ET = RSExportPointerType::Create(Context,
888                                       static_cast<const clang::PointerType*>(CTI),
889                                       TypeName);
890      // FIXME: free the name (allocated in RSExportType::GetTypeName)
891      delete [] TypeName.data();
892      break;
893    }
894    case clang::Type::ExtVector: {
895      ET = RSExportVectorType::Create(Context,
896                                      static_cast<const clang::ExtVectorType*>(CTI),
897                                      TypeName);
898      break;
899    }
900    case clang::Type::ConstantArray: {
901      ET = RSExportConstantArrayType::Create(
902              Context,
903              static_cast<const clang::ConstantArrayType*>(CTI));
904      break;
905    }
906    default: {
907      Context->ReportError("unknown type cannot be exported: '%0'")
908          << T->getTypeClassName();
909      break;
910    }
911  }
912
913  return ET;
914}
915
916RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T,
917                                   ExportKind EK, const clang::VarDecl *VD) {
918  llvm::StringRef TypeName;
919  if (NormalizeType(T, TypeName, Context, VD, EK)) {
920    return Create(Context, T, TypeName, EK);
921  } else {
922    return nullptr;
923  }
924}
925
926RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
927                                           const clang::VarDecl *VD) {
928  return RSExportType::Create(Context, GetTypeOfDecl(VD),
929                              NotLegacyKernelArgument, VD);
930}
931
932size_t RSExportType::getStoreSize() const {
933  return getRSContext()->getDataLayout()->getTypeStoreSize(getLLVMType());
934}
935
936size_t RSExportType::getAllocSize() const {
937    return getRSContext()->getDataLayout()->getTypeAllocSize(getLLVMType());
938}
939
940RSExportType::RSExportType(RSContext *Context,
941                           ExportClass Class,
942                           const llvm::StringRef &Name)
943    : RSExportable(Context, RSExportable::EX_TYPE),
944      mClass(Class),
945      // Make a copy on Name since memory stored @Name is either allocated in
946      // ASTContext or allocated in GetTypeName which will be destroyed later.
947      mName(Name.data(), Name.size()),
948      mLLVMType(nullptr) {
949  // Don't cache the type whose name start with '<'. Those type failed to
950  // get their name since constructing their name in GetTypeName() requiring
951  // complicated work.
952  if (!IsDummyName(Name)) {
953    // TODO(zonr): Need to check whether the insertion is successful or not.
954    Context->insertExportType(llvm::StringRef(Name), this);
955  }
956
957}
958
959bool RSExportType::keep() {
960  if (!RSExportable::keep())
961    return false;
962  // Invalidate converted LLVM type.
963  mLLVMType = nullptr;
964  return true;
965}
966
967bool RSExportType::equals(const RSExportable *E) const {
968  CHECK_PARENT_EQUALITY(RSExportable, E);
969  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
970}
971
972RSExportType::~RSExportType() {
973}
974
975/************************** RSExportPrimitiveType **************************/
976llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
977RSExportPrimitiveType::RSSpecificTypeMap;
978
979bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
980  if ((T != nullptr) && (T->getTypeClass() == clang::Type::Builtin))
981    return true;
982  else
983    return false;
984}
985
986DataType
987RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
988  if (TypeName.empty())
989    return DataTypeUnknown;
990
991  if (RSSpecificTypeMap->empty()) {
992    for (int i = 0; i < MatrixAndObjectDataTypesCount; i++) {
993      (*RSSpecificTypeMap)[MatrixAndObjectDataTypes[i].name] =
994          MatrixAndObjectDataTypes[i].dataType;
995    }
996  }
997
998  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
999  if (I == RSSpecificTypeMap->end())
1000    return DataTypeUnknown;
1001  else
1002    return I->getValue();
1003}
1004
1005DataType RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
1006  T = GetCanonicalType(T);
1007  if ((T == nullptr) || (T->getTypeClass() != clang::Type::Record))
1008    return DataTypeUnknown;
1009
1010  return GetRSSpecificType( RSExportType::GetTypeName(T) );
1011}
1012
1013bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
1014    if (DT < 0 || DT >= DataTypeMax) {
1015        return false;
1016    }
1017    return gReflectionTypes[DT].category == MatrixDataType;
1018}
1019
1020bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
1021    if (DT < 0 || DT >= DataTypeMax) {
1022        return false;
1023    }
1024    return gReflectionTypes[DT].category == ObjectDataType;
1025}
1026
1027bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
1028  bool RSObjectTypeSeen = false;
1029  while (T && T->isArrayType()) {
1030    T = T->getArrayElementTypeNoTypeQual();
1031  }
1032
1033  const clang::RecordType *RT = T->getAsStructureType();
1034  if (!RT) {
1035    return false;
1036  }
1037
1038  const clang::RecordDecl *RD = RT->getDecl();
1039  if (RD) {
1040    RD = RD->getDefinition();
1041  }
1042  if (!RD) {
1043    return false;
1044  }
1045
1046  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1047         FE = RD->field_end();
1048       FI != FE;
1049       FI++) {
1050    // We just look through all field declarations to see if we find a
1051    // declaration for an RS object type (or an array of one).
1052    const clang::FieldDecl *FD = *FI;
1053    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1054    while (FT && FT->isArrayType()) {
1055      FT = FT->getArrayElementTypeNoTypeQual();
1056    }
1057
1058    DataType DT = GetRSSpecificType(FT);
1059    if (IsRSObjectType(DT)) {
1060      // RS object types definitely need to be zero-initialized
1061      RSObjectTypeSeen = true;
1062    } else {
1063      switch (DT) {
1064        case DataTypeRSMatrix2x2:
1065        case DataTypeRSMatrix3x3:
1066        case DataTypeRSMatrix4x4:
1067          // Matrix types should get zero-initialized as well
1068          RSObjectTypeSeen = true;
1069          break;
1070        default:
1071          // Ignore all other primitive types
1072          break;
1073      }
1074      while (FT && FT->isArrayType()) {
1075        FT = FT->getArrayElementTypeNoTypeQual();
1076      }
1077      if (FT->isStructureType()) {
1078        // Recursively handle structs of structs (even though these can't
1079        // be exported, it is possible for a user to have them internally).
1080        RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
1081      }
1082    }
1083  }
1084
1085  return RSObjectTypeSeen;
1086}
1087
1088size_t RSExportPrimitiveType::GetElementSizeInBits(const RSExportPrimitiveType *EPT) {
1089  int type = EPT->getType();
1090  slangAssert((type > DataTypeUnknown && type < DataTypeMax) &&
1091              "RSExportPrimitiveType::GetElementSizeInBits : unknown data type");
1092  // All RS object types are 256 bits in 64-bit RS.
1093  if (EPT->isRSObjectType() && EPT->getRSContext()->is64Bit()) {
1094    return 256;
1095  }
1096  return gReflectionTypes[type].size_in_bits;
1097}
1098
1099DataType
1100RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
1101  if (T == nullptr)
1102    return DataTypeUnknown;
1103
1104  switch (T->getTypeClass()) {
1105    case clang::Type::Builtin: {
1106      const clang::BuiltinType *BT =
1107              static_cast<const clang::BuiltinType*>(T->getCanonicalTypeInternal().getTypePtr());
1108      BuiltinInfo *info = FindBuiltinType(BT->getKind());
1109      if (info != nullptr) {
1110        return info->type;
1111      }
1112      // The size of type WChar depend on platform so we abandon the support
1113      // to them.
1114      Context->ReportError("built-in type cannot be exported: '%0'")
1115          << T->getTypeClassName();
1116      break;
1117    }
1118    case clang::Type::Record: {
1119      // must be RS object type
1120      return RSExportPrimitiveType::GetRSSpecificType(T);
1121    }
1122    default: {
1123      Context->ReportError("primitive type cannot be exported: '%0'")
1124          << T->getTypeClassName();
1125      break;
1126    }
1127  }
1128
1129  return DataTypeUnknown;
1130}
1131
1132RSExportPrimitiveType
1133*RSExportPrimitiveType::Create(RSContext *Context,
1134                               const clang::Type *T,
1135                               const llvm::StringRef &TypeName,
1136                               bool Normalized) {
1137  DataType DT = GetDataType(Context, T);
1138
1139  if ((DT == DataTypeUnknown) || TypeName.empty())
1140    return nullptr;
1141  else
1142    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
1143                                     DT, Normalized);
1144}
1145
1146RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
1147                                                     const clang::Type *T) {
1148  llvm::StringRef TypeName;
1149  if (RSExportType::NormalizeType(T, TypeName, Context, nullptr,
1150                                  NotLegacyKernelArgument) &&
1151      IsPrimitiveType(T)) {
1152    return Create(Context, T, TypeName);
1153  } else {
1154    return nullptr;
1155  }
1156}
1157
1158llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
1159  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1160
1161  if (isRSObjectType()) {
1162    // struct {
1163    //   int *p;
1164    // } __attribute__((packed, aligned(pointer_size)))
1165    //
1166    // which is
1167    //
1168    // <{ [1 x i32] }> in LLVM
1169    //
1170    std::vector<llvm::Type *> Elements;
1171    if (getRSContext()->is64Bit()) {
1172      // 64-bit path
1173      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt64Ty(C), 4));
1174      return llvm::StructType::get(C, Elements, true);
1175    } else {
1176      // 32-bit legacy path
1177      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
1178      return llvm::StructType::get(C, Elements, true);
1179    }
1180  }
1181
1182  switch (mType) {
1183    case DataTypeFloat16: {
1184      return llvm::Type::getHalfTy(C);
1185      break;
1186    }
1187    case DataTypeFloat32: {
1188      return llvm::Type::getFloatTy(C);
1189      break;
1190    }
1191    case DataTypeFloat64: {
1192      return llvm::Type::getDoubleTy(C);
1193      break;
1194    }
1195    case DataTypeBoolean: {
1196      return llvm::Type::getInt1Ty(C);
1197      break;
1198    }
1199    case DataTypeSigned8:
1200    case DataTypeUnsigned8: {
1201      return llvm::Type::getInt8Ty(C);
1202      break;
1203    }
1204    case DataTypeSigned16:
1205    case DataTypeUnsigned16:
1206    case DataTypeUnsigned565:
1207    case DataTypeUnsigned5551:
1208    case DataTypeUnsigned4444: {
1209      return llvm::Type::getInt16Ty(C);
1210      break;
1211    }
1212    case DataTypeSigned32:
1213    case DataTypeUnsigned32: {
1214      return llvm::Type::getInt32Ty(C);
1215      break;
1216    }
1217    case DataTypeSigned64:
1218    case DataTypeUnsigned64: {
1219      return llvm::Type::getInt64Ty(C);
1220      break;
1221    }
1222    default: {
1223      slangAssert(false && "Unknown data type");
1224    }
1225  }
1226
1227  return nullptr;
1228}
1229
1230bool RSExportPrimitiveType::equals(const RSExportable *E) const {
1231  CHECK_PARENT_EQUALITY(RSExportType, E);
1232  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
1233}
1234
1235RSReflectionType *RSExportPrimitiveType::getRSReflectionType(DataType DT) {
1236  if (DT > DataTypeUnknown && DT < DataTypeMax) {
1237    return &gReflectionTypes[DT];
1238  } else {
1239    return nullptr;
1240  }
1241}
1242
1243/**************************** RSExportPointerType ****************************/
1244
1245RSExportPointerType
1246*RSExportPointerType::Create(RSContext *Context,
1247                             const clang::PointerType *PT,
1248                             const llvm::StringRef &TypeName) {
1249  const clang::Type *PointeeType = GetPointeeType(PT);
1250  const RSExportType *PointeeET;
1251
1252  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
1253    PointeeET = RSExportType::Create(Context, PointeeType,
1254                                     NotLegacyKernelArgument);
1255  } else {
1256    // Double or higher dimension of pointer, export as int*
1257    PointeeET = RSExportPrimitiveType::Create(Context,
1258                    Context->getASTContext().IntTy.getTypePtr());
1259  }
1260
1261  if (PointeeET == nullptr) {
1262    // Error diagnostic is emitted for corresponding pointee type
1263    return nullptr;
1264  }
1265
1266  return new RSExportPointerType(Context, TypeName, PointeeET);
1267}
1268
1269llvm::Type *RSExportPointerType::convertToLLVMType() const {
1270  llvm::Type *PointeeType = mPointeeType->getLLVMType();
1271  return llvm::PointerType::getUnqual(PointeeType);
1272}
1273
1274bool RSExportPointerType::keep() {
1275  if (!RSExportType::keep())
1276    return false;
1277  const_cast<RSExportType*>(mPointeeType)->keep();
1278  return true;
1279}
1280
1281bool RSExportPointerType::equals(const RSExportable *E) const {
1282  CHECK_PARENT_EQUALITY(RSExportType, E);
1283  return (static_cast<const RSExportPointerType*>(E)
1284              ->getPointeeType()->equals(getPointeeType()));
1285}
1286
1287/***************************** RSExportVectorType *****************************/
1288llvm::StringRef
1289RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
1290  const clang::Type *ElementType = GetExtVectorElementType(EVT);
1291  llvm::StringRef name;
1292
1293  if ((ElementType->getTypeClass() != clang::Type::Builtin))
1294    return name;
1295
1296  const clang::BuiltinType *BT =
1297          static_cast<const clang::BuiltinType*>(
1298              ElementType->getCanonicalTypeInternal().getTypePtr());
1299
1300  if ((EVT->getNumElements() < 1) ||
1301      (EVT->getNumElements() > 4))
1302    return name;
1303
1304  BuiltinInfo *info = FindBuiltinType(BT->getKind());
1305  if (info != nullptr) {
1306    int I = EVT->getNumElements() - 1;
1307    if (I < kMaxVectorSize) {
1308      name = info->cname[I];
1309    } else {
1310      slangAssert(false && "Max vector is 4");
1311    }
1312  }
1313  return name;
1314}
1315
1316RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
1317                                               const clang::ExtVectorType *EVT,
1318                                               const llvm::StringRef &TypeName,
1319                                               bool Normalized) {
1320  slangAssert(EVT != nullptr && EVT->getTypeClass() == clang::Type::ExtVector);
1321
1322  const clang::Type *ElementType = GetExtVectorElementType(EVT);
1323  DataType DT = RSExportPrimitiveType::GetDataType(Context, ElementType);
1324
1325  if (DT != DataTypeUnknown)
1326    return new RSExportVectorType(Context,
1327                                  TypeName,
1328                                  DT,
1329                                  Normalized,
1330                                  EVT->getNumElements());
1331  else
1332    return nullptr;
1333}
1334
1335llvm::Type *RSExportVectorType::convertToLLVMType() const {
1336  llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
1337  return llvm::VectorType::get(ElementType, getNumElement());
1338}
1339
1340bool RSExportVectorType::equals(const RSExportable *E) const {
1341  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
1342  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
1343              == getNumElement());
1344}
1345
1346/***************************** RSExportMatrixType *****************************/
1347RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
1348                                               const clang::RecordType *RT,
1349                                               const llvm::StringRef &TypeName,
1350                                               unsigned Dim) {
1351  slangAssert((RT != nullptr) && (RT->getTypeClass() == clang::Type::Record));
1352  slangAssert((Dim > 1) && "Invalid dimension of matrix");
1353
1354  // Check whether the struct rs_matrix is in our expected form (but assume it's
1355  // correct if we're not sure whether it's correct or not)
1356  const clang::RecordDecl* RD = RT->getDecl();
1357  RD = RD->getDefinition();
1358  if (RD != nullptr) {
1359    // Find definition, perform further examination
1360    if (RD->field_empty()) {
1361      Context->ReportError(
1362          RD->getLocation(),
1363          "invalid matrix struct: must have 1 field for saving values: '%0'")
1364          << RD->getName();
1365      return nullptr;
1366    }
1367
1368    clang::RecordDecl::field_iterator FIT = RD->field_begin();
1369    const clang::FieldDecl *FD = *FIT;
1370    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1371    if ((FT == nullptr) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
1372      Context->ReportError(RD->getLocation(),
1373                           "invalid matrix struct: first field should"
1374                           " be an array with constant size: '%0'")
1375          << RD->getName();
1376      return nullptr;
1377    }
1378    const clang::ConstantArrayType *CAT =
1379      static_cast<const clang::ConstantArrayType *>(FT);
1380    const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1381    if ((ElementType == nullptr) ||
1382        (ElementType->getTypeClass() != clang::Type::Builtin) ||
1383        (static_cast<const clang::BuiltinType *>(ElementType)->getKind() !=
1384         clang::BuiltinType::Float)) {
1385      Context->ReportError(RD->getLocation(),
1386                           "invalid matrix struct: first field "
1387                           "should be a float array: '%0'")
1388          << RD->getName();
1389      return nullptr;
1390    }
1391
1392    if (CAT->getSize() != Dim * Dim) {
1393      Context->ReportError(RD->getLocation(),
1394                           "invalid matrix struct: first field "
1395                           "should be an array with size %0: '%1'")
1396          << (Dim * Dim) << (RD->getName());
1397      return nullptr;
1398    }
1399
1400    FIT++;
1401    if (FIT != RD->field_end()) {
1402      Context->ReportError(RD->getLocation(),
1403                           "invalid matrix struct: must have "
1404                           "exactly 1 field: '%0'")
1405          << RD->getName();
1406      return nullptr;
1407    }
1408  }
1409
1410  return new RSExportMatrixType(Context, TypeName, Dim);
1411}
1412
1413llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1414  // Construct LLVM type:
1415  // struct {
1416  //  float X[mDim * mDim];
1417  // }
1418
1419  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1420  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1421                                            mDim * mDim);
1422  return llvm::StructType::get(C, X, false);
1423}
1424
1425bool RSExportMatrixType::equals(const RSExportable *E) const {
1426  CHECK_PARENT_EQUALITY(RSExportType, E);
1427  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1428}
1429
1430/************************* RSExportConstantArrayType *************************/
1431RSExportConstantArrayType
1432*RSExportConstantArrayType::Create(RSContext *Context,
1433                                   const clang::ConstantArrayType *CAT) {
1434  slangAssert(CAT != nullptr && CAT->getTypeClass() == clang::Type::ConstantArray);
1435
1436  slangAssert((CAT->getSize().getActiveBits() < 32) && "array too large");
1437
1438  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1439  slangAssert((Size > 0) && "Constant array should have size greater than 0");
1440
1441  const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1442  RSExportType *ElementET = RSExportType::Create(Context, ElementType,
1443                                                 NotLegacyKernelArgument);
1444
1445  if (ElementET == nullptr) {
1446    return nullptr;
1447  }
1448
1449  return new RSExportConstantArrayType(Context,
1450                                       ElementET,
1451                                       Size);
1452}
1453
1454llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1455  return llvm::ArrayType::get(mElementType->getLLVMType(), getNumElement());
1456}
1457
1458bool RSExportConstantArrayType::keep() {
1459  if (!RSExportType::keep())
1460    return false;
1461  const_cast<RSExportType*>(mElementType)->keep();
1462  return true;
1463}
1464
1465bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1466  CHECK_PARENT_EQUALITY(RSExportType, E);
1467  const RSExportConstantArrayType *RHS =
1468      static_cast<const RSExportConstantArrayType*>(E);
1469  return ((getNumElement() == RHS->getNumElement()) &&
1470          (getElementType()->equals(RHS->getElementType())));
1471}
1472
1473/**************************** RSExportRecordType ****************************/
1474RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1475                                               const clang::RecordType *RT,
1476                                               const llvm::StringRef &TypeName,
1477                                               bool mIsArtificial) {
1478  slangAssert(RT != nullptr && RT->getTypeClass() == clang::Type::Record);
1479
1480  const clang::RecordDecl *RD = RT->getDecl();
1481  slangAssert(RD->isStruct());
1482
1483  RD = RD->getDefinition();
1484  if (RD == nullptr) {
1485    slangAssert(false && "struct is not defined in this module");
1486    return nullptr;
1487  }
1488
1489  // Struct layout construct by clang. We rely on this for obtaining the
1490  // alloc size of a struct and offset of every field in that struct.
1491  const clang::ASTRecordLayout *RL =
1492      &Context->getASTContext().getASTRecordLayout(RD);
1493  slangAssert((RL != nullptr) &&
1494      "Failed to retrieve the struct layout from Clang.");
1495
1496  RSExportRecordType *ERT =
1497      new RSExportRecordType(Context,
1498                             TypeName,
1499                             RD->hasAttr<clang::PackedAttr>(),
1500                             mIsArtificial,
1501                             RL->getDataSize().getQuantity(),
1502                             RL->getSize().getQuantity());
1503  unsigned int Index = 0;
1504
1505  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1506           FE = RD->field_end();
1507       FI != FE;
1508       FI++, Index++) {
1509
1510    // FIXME: All fields should be primitive type
1511    slangAssert(FI->getKind() == clang::Decl::Field);
1512    clang::FieldDecl *FD = *FI;
1513
1514    if (FD->isBitField()) {
1515      return nullptr;
1516    }
1517
1518    // Type
1519    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1520
1521    if (ET != nullptr) {
1522      ERT->mFields.push_back(
1523          new Field(ET, FD->getName(), ERT,
1524                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1525    } else {
1526      Context->ReportError(RD->getLocation(),
1527                           "field type cannot be exported: '%0.%1'")
1528          << RD->getName() << FD->getName();
1529      return nullptr;
1530    }
1531  }
1532
1533  return ERT;
1534}
1535
1536llvm::Type *RSExportRecordType::convertToLLVMType() const {
1537  // Create an opaque type since struct may reference itself recursively.
1538
1539  // TODO(sliao): LLVM took out the OpaqueType. Any other to migrate to?
1540  std::vector<llvm::Type*> FieldTypes;
1541
1542  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1543       FI != FE;
1544       FI++) {
1545    const Field *F = *FI;
1546    const RSExportType *FET = F->getType();
1547
1548    FieldTypes.push_back(FET->getLLVMType());
1549  }
1550
1551  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1552                                               FieldTypes,
1553                                               mIsPacked);
1554  if (ST != nullptr) {
1555    return ST;
1556  } else {
1557    return nullptr;
1558  }
1559}
1560
1561bool RSExportRecordType::keep() {
1562  if (!RSExportType::keep())
1563    return false;
1564  for (std::list<const Field*>::iterator I = mFields.begin(),
1565          E = mFields.end();
1566       I != E;
1567       I++) {
1568    const_cast<RSExportType*>((*I)->getType())->keep();
1569  }
1570  return true;
1571}
1572
1573bool RSExportRecordType::equals(const RSExportable *E) const {
1574  CHECK_PARENT_EQUALITY(RSExportType, E);
1575
1576  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1577
1578  if (ERT->getFields().size() != getFields().size())
1579    return false;
1580
1581  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1582
1583  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1584    if (!(*AI)->getType()->equals((*BI)->getType()))
1585      return false;
1586    AI++;
1587    BI++;
1588  }
1589
1590  return true;
1591}
1592
1593void RSExportType::convertToRTD(RSReflectionTypeData *rtd) const {
1594    memset(rtd, 0, sizeof(*rtd));
1595    rtd->vecSize = 1;
1596
1597    switch(getClass()) {
1598    case RSExportType::ExportClassPrimitive: {
1599            const RSExportPrimitiveType *EPT = static_cast<const RSExportPrimitiveType*>(this);
1600            rtd->type = RSExportPrimitiveType::getRSReflectionType(EPT);
1601            return;
1602        }
1603    case RSExportType::ExportClassPointer: {
1604            const RSExportPointerType *EPT = static_cast<const RSExportPointerType*>(this);
1605            const RSExportType *PointeeType = EPT->getPointeeType();
1606            PointeeType->convertToRTD(rtd);
1607            rtd->isPointer = true;
1608            return;
1609        }
1610    case RSExportType::ExportClassVector: {
1611            const RSExportVectorType *EVT = static_cast<const RSExportVectorType*>(this);
1612            rtd->type = EVT->getRSReflectionType(EVT);
1613            rtd->vecSize = EVT->getNumElement();
1614            return;
1615        }
1616    case RSExportType::ExportClassMatrix: {
1617            const RSExportMatrixType *EMT = static_cast<const RSExportMatrixType*>(this);
1618            unsigned Dim = EMT->getDim();
1619            slangAssert((Dim >= 2) && (Dim <= 4));
1620            rtd->type = &gReflectionTypes[15 + Dim-2];
1621            return;
1622        }
1623    case RSExportType::ExportClassConstantArray: {
1624            const RSExportConstantArrayType* CAT =
1625              static_cast<const RSExportConstantArrayType*>(this);
1626            CAT->getElementType()->convertToRTD(rtd);
1627            rtd->arraySize = CAT->getNumElement();
1628            return;
1629        }
1630    case RSExportType::ExportClassRecord: {
1631            slangAssert(!"RSExportType::ExportClassRecord not implemented");
1632            return;// RS_TYPE_CLASS_NAME_PREFIX + ET->getName() + ".Item";
1633        }
1634    default: {
1635            slangAssert(false && "Unknown class of type");
1636        }
1637    }
1638}
1639
1640
1641}  // namespace slang
1642