slang_rs_export_type.cpp revision 33ea573b6df7b7fe48d2b68d4c479f33082e3c0d
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_export_type.h"

#include <list>
#include <set>
#include <string>
#include <vector>

#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/RecordLayout.h"

#include "llvm/ADT/StringExtras.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"

#include "slang_assert.h"
#include "slang_rs_context.h"
#include "slang_rs_export_element.h"
#include "slang_version.h"
35
// Expands to a statement that makes the enclosing function return false
// unless ParentClass::equals(E) holds; presumably used by equals()
// implementations to defer to their parent class first (call sites are
// outside this code -- confirm).
// NOTE(review): the expansion is a bare `if`, so placing it as the body of an
// unbraced `if` that has an `else` would bind that `else` to the macro's own
// `if`. Consider a do { ... } while (0) form once every call site is known to
// end with a semicolon.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  if (!ParentClass::equals(E))                \
    return false;
39
40namespace slang {
41
42namespace {
43
/* For the data types we support, their category, names, and size (in bits).
 *
 * IMPORTANT: The data types in this table should be at the same index
 * as specified by the corresponding DataType enum.
 *
 * Columns follow the RSReflectionType declaration (declared elsewhere --
 * confirm field names there): category, type name, short RS name, size in
 * bits, C name, Java name, two further name columns (e.g. "UByte"/"Short"
 * for UNSIGNED_8; presumably vector-name prefixes -- TODO confirm), and a
 * trailing flag. The flag is true only for UNSIGNED_8/16/32, whose Java name
 * is a wider signed type than the C type (presumably marking Java-side
 * widening -- TODO confirm).
 */
static RSReflectionType gReflectionTypes[] = {
    {PrimitiveDataType, "FLOAT_16", "F16", 16, "half", "short", "Half", "Half", false},
    {PrimitiveDataType, "FLOAT_32", "F32", 32, "float", "float", "Float", "Float", false},
    {PrimitiveDataType, "FLOAT_64", "F64", 64, "double", "double", "Double", "Double",false},
    {PrimitiveDataType, "SIGNED_8", "I8", 8, "int8_t", "byte", "Byte", "Byte", false},
    {PrimitiveDataType, "SIGNED_16", "I16", 16, "int16_t", "short", "Short", "Short", false},
    {PrimitiveDataType, "SIGNED_32", "I32", 32, "int32_t", "int", "Int", "Int", false},
    {PrimitiveDataType, "SIGNED_64", "I64", 64, "int64_t", "long", "Long", "Long", false},
    {PrimitiveDataType, "UNSIGNED_8", "U8", 8, "uint8_t", "short", "UByte", "Short", true},
    {PrimitiveDataType, "UNSIGNED_16", "U16", 16, "uint16_t", "int", "UShort", "Int", true},
    {PrimitiveDataType, "UNSIGNED_32", "U32", 32, "uint32_t", "long", "UInt", "Long", true},
    {PrimitiveDataType, "UNSIGNED_64", "U64", 64, "uint64_t", "long", "ULong", "Long", false},

    {PrimitiveDataType, "BOOLEAN", "BOOLEAN", 8, "bool", "boolean", nullptr, nullptr, false},

    // Packed pixel formats: only a type name and size, no C/Java names.
    {PrimitiveDataType, "UNSIGNED_5_6_5", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},
    {PrimitiveDataType, "UNSIGNED_5_5_5_1", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},
    {PrimitiveDataType, "UNSIGNED_4_4_4_4", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},

    // Matrices are N*N 32-bit elements.
    {MatrixDataType, "MATRIX_2X2", nullptr, 4*32, "rsMatrix_2x2", "Matrix2f", nullptr, nullptr, false},
    {MatrixDataType, "MATRIX_3X3", nullptr, 9*32, "rsMatrix_3x3", "Matrix3f", nullptr, nullptr, false},
    {MatrixDataType, "MATRIX_4X4", nullptr, 16*32, "rsMatrix_4x4", "Matrix4f", nullptr, nullptr, false},

    // RS object types are 32 bits in 32-bit RS, but 256 bits in 64-bit RS.
    // This is handled specially by the GetSizeInBits() method.
    {ObjectDataType, "RS_ELEMENT", "ELEMENT", 32, "Element", "Element", nullptr, nullptr, false},
    {ObjectDataType, "RS_TYPE", "TYPE", 32, "Type", "Type", nullptr, nullptr, false},
    {ObjectDataType, "RS_ALLOCATION", "ALLOCATION", 32, "Allocation", "Allocation", nullptr, nullptr, false},
    {ObjectDataType, "RS_SAMPLER", "SAMPLER", 32, "Sampler", "Sampler", nullptr, nullptr, false},
    {ObjectDataType, "RS_SCRIPT", "SCRIPT", 32, "Script", "Script", nullptr, nullptr, false},
    {ObjectDataType, "RS_MESH", "MESH", 32, "Mesh", "Mesh", nullptr, nullptr, false},
    {ObjectDataType, "RS_PATH", "PATH", 32, "Path", "Path", nullptr, nullptr, false},

    {ObjectDataType, "RS_PROGRAM_FRAGMENT", "PROGRAM_FRAGMENT", 32, "ProgramFragment", "ProgramFragment", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_VERTEX", "PROGRAM_VERTEX", 32, "ProgramVertex", "ProgramVertex", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_RASTER", "PROGRAM_RASTER", 32, "ProgramRaster", "ProgramRaster", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_STORE", "PROGRAM_STORE", 32, "ProgramStore", "ProgramStore", nullptr, nullptr, false},
    {ObjectDataType, "RS_FONT", "FONT", 32, "Font", "Font", nullptr, nullptr, false}
};
88
// Number of names kept per builtin: the scalar plus vectors of 2, 3 and 4
// elements.
const int kMaxVectorSize = 4;

// Describes one clang builtin type that RS can export: its RS DataType and
// its RS spelling, where cname[0] is the scalar name and cname[i] is the name
// of the (i + 1)-element vector (e.g. "uchar", "uchar2", "uchar3", "uchar4").
struct BuiltinInfo {
  clang::BuiltinType::Kind builtinTypeKind;
  DataType type;
  /* TODO If we return std::string instead of llvm::StringRef, we could build
   * the name instead of duplicating the entries.
   */
  const char *cname[kMaxVectorSize];
};


// Lookup table searched by FindBuiltinType(). Builtin kinds missing here are
// treated as not exportable by TypeExportableHelper(). Several clang kinds
// share one RS name: Char_U/UChar -> uchar, Char_S/SChar -> char,
// Long/LongLong -> long, ULong/ULongLong -> ulong, and Char16/Char32 are
// spelled as short/int.
BuiltinInfo BuiltinInfoTable[] = {
    {clang::BuiltinType::Bool, DataTypeBoolean,
     {"bool", "bool2", "bool3", "bool4"}},
    {clang::BuiltinType::Char_U, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::UChar, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::Char16, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Char32, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::UShort, DataTypeUnsigned16,
     {"ushort", "ushort2", "ushort3", "ushort4"}},
    {clang::BuiltinType::UInt, DataTypeUnsigned32,
     {"uint", "uint2", "uint3", "uint4"}},
    {clang::BuiltinType::ULong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},
    {clang::BuiltinType::ULongLong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},

    {clang::BuiltinType::Char_S, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::SChar, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::Short, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Int, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::Long, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::LongLong, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::Half, DataTypeFloat16,
     {"half", "half2", "half3", "half4"}},
    {clang::BuiltinType::Float, DataTypeFloat32,
     {"float", "float2", "float3", "float4"}},
    {clang::BuiltinType::Double, DataTypeFloat64,
     {"double", "double2", "double3", "double4"}},
};
// Number of entries in BuiltinInfoTable.
const int BuiltinInfoTableCount = sizeof(BuiltinInfoTable) / sizeof(BuiltinInfoTable[0]);
141
// Associates an RS-specific type name with its DataType.
struct NameAndPrimitiveType {
  const char *name;
  DataType dataType;
};

// RS matrix and object type names and their DataType values. Presumably
// consulted when resolving RS-specific record types by name -- confirm at the
// use site.
static NameAndPrimitiveType MatrixAndObjectDataTypes[] = {
    {"rs_matrix2x2", DataTypeRSMatrix2x2},
    {"rs_matrix3x3", DataTypeRSMatrix3x3},
    {"rs_matrix4x4", DataTypeRSMatrix4x4},
    {"rs_element", DataTypeRSElement},
    {"rs_type", DataTypeRSType},
    {"rs_allocation", DataTypeRSAllocation},
    {"rs_sampler", DataTypeRSSampler},
    {"rs_script", DataTypeRSScript},
    {"rs_mesh", DataTypeRSMesh},
    {"rs_path", DataTypeRSPath},
    {"rs_program_fragment", DataTypeRSProgramFragment},
    {"rs_program_vertex", DataTypeRSProgramVertex},
    {"rs_program_raster", DataTypeRSProgramRaster},
    {"rs_program_store", DataTypeRSProgramStore},
    {"rs_font", DataTypeRSFont},
};

// Number of entries in MatrixAndObjectDataTypes.
const int MatrixAndObjectDataTypesCount =
    sizeof(MatrixAndObjectDataTypes) / sizeof(MatrixAndObjectDataTypes[0]);
167
// Forward declaration: TypeExportableHelper() and
// ConstantArrayTypeExportableHelper() call each other.
static const clang::Type *TypeExportableHelper(
    const clang::Type *T,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    slang::RSContext *Context,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord);
174
175template <unsigned N>
176static void ReportTypeError(slang::RSContext *Context,
177                            const clang::NamedDecl *ND,
178                            const clang::RecordDecl *TopLevelRecord,
179                            const char (&Message)[N],
180                            unsigned int TargetAPI = 0) {
181  // Attempt to use the type declaration first (if we have one).
182  // Fall back to the variable definition, if we are looking at something
183  // like an array declaration that can't be exported.
184  if (TopLevelRecord) {
185    Context->ReportError(TopLevelRecord->getLocation(), Message)
186        << TopLevelRecord->getName() << TargetAPI;
187  } else if (ND) {
188    Context->ReportError(ND->getLocation(), Message) << ND->getName()
189                                                     << TargetAPI;
190  } else {
191    slangAssert(false && "Variables should be validated before exporting");
192  }
193}
194
195static const clang::Type *ConstantArrayTypeExportableHelper(
196    const clang::ConstantArrayType *CAT,
197    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
198    slang::RSContext *Context,
199    const clang::VarDecl *VD,
200    const clang::RecordDecl *TopLevelRecord) {
201  // Check element type
202  const clang::Type *ElementType = GetConstantArrayElementType(CAT);
203  if (ElementType->isArrayType()) {
204    ReportTypeError(Context, VD, TopLevelRecord,
205                    "multidimensional arrays cannot be exported: '%0'");
206    return nullptr;
207  } else if (ElementType->isExtVectorType()) {
208    const clang::ExtVectorType *EVT =
209        static_cast<const clang::ExtVectorType*>(ElementType);
210    unsigned numElements = EVT->getNumElements();
211
212    const clang::Type *BaseElementType = GetExtVectorElementType(EVT);
213    if (!RSExportPrimitiveType::IsPrimitiveType(BaseElementType)) {
214      ReportTypeError(Context, VD, TopLevelRecord,
215        "vectors of non-primitive types cannot be exported: '%0'");
216      return nullptr;
217    }
218
219    if (numElements == 3 && CAT->getSize() != 1) {
220      ReportTypeError(Context, VD, TopLevelRecord,
221        "arrays of width 3 vector types cannot be exported: '%0'");
222      return nullptr;
223    }
224  }
225
226  if (TypeExportableHelper(ElementType, SPS, Context, VD,
227                           TopLevelRecord) == nullptr) {
228    return nullptr;
229  } else {
230    return CAT;
231  }
232}
233
234BuiltinInfo *FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind) {
235  for (int i = 0; i < BuiltinInfoTableCount; i++) {
236    if (builtinTypeKind == BuiltinInfoTable[i].builtinTypeKind) {
237      return &BuiltinInfoTable[i];
238    }
239  }
240  return nullptr;
241}
242
// Recursive worker behind TypeExportable().
//
// Returns the type to export for T (its canonical type, or `int` for enums),
// or nullptr if T cannot be exported. Most rejections report a diagnostic
// through Context. Some return nullptr silently: unknown builtins,
// unsupported vector widths, vectors of non-builtin elements, and records
// with flexible array or object members.
//
// T - type to check (canonicalized first).
// SPS - record types already entered during this query; a type found here is
//       accepted without being re-checked.
// Context - used to report diagnostics.
// VD - (optional) variable being exported; fallback diagnostic location.
// TopLevelRecord - outermost struct being checked, or nullptr when not yet
//                  inside a struct.
static const clang::Type *TypeExportableHelper(
    clang::Type const *T,
    llvm::SmallPtrSet<clang::Type const *, 8> &SPS,
    slang::RSContext *Context,
    clang::VarDecl const *VD,
    clang::RecordDecl const *TopLevelRecord) {
  // Normalize first
  if ((T = GetCanonicalType(T)) == nullptr)
    return nullptr;

  // Records are inserted below before their fields are checked.
  if (SPS.count(T))
    return T;

  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();

  switch (T->getTypeClass()) {
    case clang::Type::Builtin: {
      // Only builtins listed in BuiltinInfoTable are exportable.
      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
      return FindBuiltinType(BT->getKind()) == nullptr ? nullptr : T;
    }
    case clang::Type::Record: {
      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
        return T;  // RS object type, no further checks are needed
      }

      // Check internal struct
      if (T->isUnionType()) {
        ReportTypeError(Context, VD, T->getAsUnionType()->getDecl(),
                        "unions cannot be exported: '%0'");
        return nullptr;
      } else if (!T->isStructureType()) {
        slangAssert(false && "Unknown type cannot be exported");
        return nullptr;
      }

      // Only a struct with a definition in this translation unit can be
      // exported.
      clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
      if (RD != nullptr) {
        RD = RD->getDefinition();
        if (RD == nullptr) {
          ReportTypeError(Context, nullptr, T->getAsStructureType()->getDecl(),
                          "struct is not defined in this module");
          return nullptr;
        }
      }

      // The first struct reached becomes the diagnostic anchor for errors
      // found anywhere beneath it.
      if (!TopLevelRecord) {
        TopLevelRecord = RD;
      }
      if (RD->getName().empty()) {
        ReportTypeError(Context, nullptr, RD,
                        "anonymous structures cannot be exported");
        return nullptr;
      }

      // Fast check
      // NOTE(review): this rejection does not report a diagnostic.
      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
        return nullptr;

      // Insert myself into checking set
      SPS.insert(T);

      // Check all element
      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
               FE = RD->field_end();
           FI != FE;
           FI++) {
        const clang::FieldDecl *FD = *FI;
        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
        FT = GetCanonicalType(FT);

        if (!TypeExportableHelper(FT, SPS, Context, VD, TopLevelRecord)) {
          return nullptr;
        }

        // We don't support bit fields yet
        //
        // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
        if (FD->isBitField()) {
          Context->ReportError(
              FD->getLocation(),
              "bit fields are not able to be exported: '%0.%1'")
              << RD->getName() << FD->getName();
          return nullptr;
        }
      }

      return T;
    }
    case clang::Type::Pointer: {
      // Pointers are rejected anywhere inside a struct.
      if (TopLevelRecord) {
        ReportTypeError(Context, VD, TopLevelRecord,
            "structures containing pointers cannot be used as the type of "
            "an exported global variable or the parameter to an exported "
            "function: '%0'");
        return nullptr;
      }

      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
      const clang::Type *PointeeType = GetPointeeType(PT);

      if (PointeeType->getTypeClass() == clang::Type::Pointer) {
        ReportTypeError(Context, VD, TopLevelRecord,
            "multiple levels of pointers cannot be exported: '%0'");
        return nullptr;
      }
      // We don't support pointer with array-type pointee or unsupported pointee
      // type
      if (PointeeType->isArrayType() ||
          (TypeExportableHelper(PointeeType, SPS, Context, VD,
                                TopLevelRecord) == nullptr))
        return nullptr;
      else
        return T;
    }
    case clang::Type::ExtVector: {
      const clang::ExtVectorType *EVT =
              static_cast<const clang::ExtVectorType*>(CTI);
      // Only vector with size 2, 3 and 4 are supported.
      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
        return nullptr;

      // Check base element type
      const clang::Type *ElementType = GetExtVectorElementType(EVT);

      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
          (TypeExportableHelper(ElementType, SPS, Context, VD,
                                TopLevelRecord) == nullptr))
        return nullptr;
      else
        return T;
    }
    case clang::Type::ConstantArray: {
      const clang::ConstantArrayType *CAT =
              static_cast<const clang::ConstantArrayType*>(CTI);

      return ConstantArrayTypeExportableHelper(CAT, SPS, Context, VD,
                                               TopLevelRecord);
    }
    case clang::Type::Enum: {
      // FIXME: We currently convert enums to integers, rather than reflecting
      // a more complete (and nicer type-safe Java version).
      return Context->getASTContext().IntTy.getTypePtr();
    }
    default: {
      slangAssert(false && "Unknown type cannot be validated");
      return nullptr;
    }
  }
}
392
393// Return the type that can be used to create RSExportType, will always return
394// the canonical type.
395//
396// If the Type T is not exportable, this function returns nullptr. DiagEngine is
397// used to generate proper Clang diagnostic messages when a non-exportable type
398// is detected. TopLevelRecord is used to capture the highest struct (in the
399// case of a nested hierarchy) for detecting other types that cannot be exported
400// (mostly pointers within a struct).
401static const clang::Type *TypeExportable(const clang::Type *T,
402                                         slang::RSContext *Context,
403                                         const clang::VarDecl *VD) {
404  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
405      llvm::SmallPtrSet<const clang::Type*, 8>();
406
407  return TypeExportableHelper(T, SPS, Context, VD, nullptr);
408}
409
410static bool ValidateRSObjectInVarDecl(slang::RSContext *Context,
411                                      clang::VarDecl *VD, bool InCompositeType,
412                                      unsigned int TargetAPI) {
413  if (TargetAPI < SLANG_JB_TARGET_API) {
414    // Only if we are already in a composite type (like an array or structure).
415    if (InCompositeType) {
416      // Only if we are actually exported (i.e. non-static).
417      if (VD->hasLinkage() &&
418          (VD->getFormalLinkage() == clang::ExternalLinkage)) {
419        // Only if we are not a pointer to an object.
420        const clang::Type *T = GetCanonicalType(VD->getType().getTypePtr());
421        if (T->getTypeClass() != clang::Type::Pointer) {
422          ReportTypeError(Context, VD, nullptr,
423                          "arrays/structures containing RS object types "
424                          "cannot be exported in target API < %1: '%0'",
425                          SLANG_JB_TARGET_API);
426          return false;
427        }
428      }
429    }
430  }
431
432  return true;
433}
434
435// Helper function for ValidateType(). We do a recursive descent on the
436// type hierarchy to ensure that we can properly export/handle the
437// declaration.
438// \return true if the variable declaration is valid,
439//         false if it is invalid (along with proper diagnostics).
440//
441// C - ASTContext (for diagnostics + builtin types).
442// T - sub-type that we are validating.
443// ND - (optional) top-level named declaration that we are validating.
444// SPS - set of types we have already seen/validated.
445// InCompositeType - true if we are within an outer composite type.
446// UnionDecl - set if we are in a sub-type of a union.
447// TargetAPI - target SDK API level.
448// IsFilterscript - whether or not we are compiling for Filterscript
449// IsExtern - is this type externally visible (i.e. extern global or parameter
450//                                             to an extern function)
451static bool ValidateTypeHelper(
452    slang::RSContext *Context,
453    clang::ASTContext &C,
454    const clang::Type *&T,
455    clang::NamedDecl *ND,
456    clang::SourceLocation Loc,
457    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
458    bool InCompositeType,
459    clang::RecordDecl *UnionDecl,
460    unsigned int TargetAPI,
461    bool IsFilterscript,
462    bool IsExtern) {
463  if ((T = GetCanonicalType(T)) == nullptr)
464    return true;
465
466  if (SPS.count(T))
467    return true;
468
469  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
470
471  switch (T->getTypeClass()) {
472    case clang::Type::Record: {
473      if (RSExportPrimitiveType::IsRSObjectType(T)) {
474        clang::VarDecl *VD = (ND ? llvm::dyn_cast<clang::VarDecl>(ND) : nullptr);
475        if (VD && !ValidateRSObjectInVarDecl(Context, VD, InCompositeType,
476                                             TargetAPI)) {
477          return false;
478        }
479      }
480
481      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
482        if (!UnionDecl) {
483          return true;
484        } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
485          ReportTypeError(Context, nullptr, UnionDecl,
486              "unions containing RS object types are not allowed");
487          return false;
488        }
489      }
490
491      clang::RecordDecl *RD = nullptr;
492
493      // Check internal struct
494      if (T->isUnionType()) {
495        RD = T->getAsUnionType()->getDecl();
496        UnionDecl = RD;
497      } else if (T->isStructureType()) {
498        RD = T->getAsStructureType()->getDecl();
499      } else {
500        slangAssert(false && "Unknown type cannot be exported");
501        return false;
502      }
503
504      if (RD != nullptr) {
505        RD = RD->getDefinition();
506        if (RD == nullptr) {
507          // FIXME
508          return true;
509        }
510      }
511
512      // Fast check
513      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
514        return false;
515
516      // Insert myself into checking set
517      SPS.insert(T);
518
519      // Check all elements
520      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
521               FE = RD->field_end();
522           FI != FE;
523           FI++) {
524        const clang::FieldDecl *FD = *FI;
525        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
526        FT = GetCanonicalType(FT);
527
528        if (!ValidateTypeHelper(Context, C, FT, ND, Loc, SPS, true, UnionDecl,
529                                TargetAPI, IsFilterscript, IsExtern)) {
530          return false;
531        }
532      }
533
534      return true;
535    }
536
537    case clang::Type::Builtin: {
538      if (IsFilterscript) {
539        clang::QualType QT = T->getCanonicalTypeInternal();
540        if (QT == C.DoubleTy ||
541            QT == C.LongDoubleTy ||
542            QT == C.LongTy ||
543            QT == C.LongLongTy) {
544          if (ND) {
545            Context->ReportError(
546                Loc,
547                "Builtin types > 32 bits in size are forbidden in "
548                "Filterscript: '%0'")
549                << ND->getName();
550          } else {
551            Context->ReportError(
552                Loc,
553                "Builtin types > 32 bits in size are forbidden in "
554                "Filterscript");
555          }
556          return false;
557        }
558      }
559      break;
560    }
561
562    case clang::Type::Pointer: {
563      if (IsFilterscript) {
564        if (ND) {
565          Context->ReportError(Loc,
566                               "Pointers are forbidden in Filterscript: '%0'")
567              << ND->getName();
568          return false;
569        } else {
570          // TODO(srhines): Find a better way to handle expressions (i.e. no
571          // NamedDecl) involving pointers in FS that should be allowed.
572          // An example would be calls to library functions like
573          // rsMatrixMultiply() that take rs_matrixNxN * types.
574        }
575      }
576
577      // Forbid pointers in structures that are externally visible.
578      if (InCompositeType && IsExtern) {
579        if (ND) {
580          Context->ReportError(Loc,
581              "structures containing pointers cannot be used as the type of "
582              "an exported global variable or the parameter to an exported "
583              "function: '%0'")
584            << ND->getName();
585        } else {
586          Context->ReportError(Loc,
587              "structures containing pointers cannot be used as the type of "
588              "an exported global variable or the parameter to an exported "
589              "function");
590        }
591        return false;
592      }
593
594      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
595      const clang::Type *PointeeType = GetPointeeType(PT);
596
597      return ValidateTypeHelper(Context, C, PointeeType, ND, Loc, SPS,
598                                InCompositeType, UnionDecl, TargetAPI,
599                                IsFilterscript, IsExtern);
600    }
601
602    case clang::Type::ExtVector: {
603      const clang::ExtVectorType *EVT =
604              static_cast<const clang::ExtVectorType*>(CTI);
605      const clang::Type *ElementType = GetExtVectorElementType(EVT);
606      if (TargetAPI < SLANG_ICS_TARGET_API &&
607          InCompositeType &&
608          EVT->getNumElements() == 3 &&
609          ND &&
610          ND->getFormalLinkage() == clang::ExternalLinkage) {
611        ReportTypeError(Context, ND, nullptr,
612                        "structs containing vectors of dimension 3 cannot "
613                        "be exported at this API level: '%0'");
614        return false;
615      }
616      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
617                                UnionDecl, TargetAPI, IsFilterscript, IsExtern);
618    }
619
620    case clang::Type::ConstantArray: {
621      const clang::ConstantArrayType *CAT = static_cast<const clang::ConstantArrayType*>(CTI);
622      const clang::Type *ElementType = GetConstantArrayElementType(CAT);
623      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
624                                UnionDecl, TargetAPI, IsFilterscript, IsExtern);
625    }
626
627    default: {
628      break;
629    }
630  }
631
632  return true;
633}
634
635}  // namespace
636
// Builds a placeholder name of the form "<type>" or, when name is non-empty,
// "<type:name>".
//
// Built with plain std::string appends so the function does not depend on
// <sstream>, which this file never includes. A null type is treated as an
// empty string instead of being streamed (streaming a null `const char *`
// is undefined behavior).
std::string CreateDummyName(const char *type, const std::string &name) {
  std::string Result(1, '<');
  if (type != nullptr) {
    Result += type;
  }
  if (!name.empty()) {
    Result += ':';
    Result += name;
  }
  Result += '>';
  return Result;
}
646
647/****************************** RSExportType ******************************/
648bool RSExportType::NormalizeType(const clang::Type *&T,
649                                 llvm::StringRef &TypeName,
650                                 RSContext *Context,
651                                 const clang::VarDecl *VD) {
652  if ((T = TypeExportable(T, Context, VD)) == nullptr) {
653    return false;
654  }
655  // Get type name
656  TypeName = RSExportType::GetTypeName(T);
657  if (Context && TypeName.empty()) {
658    if (VD) {
659      Context->ReportError(VD->getLocation(),
660                           "anonymous types cannot be exported");
661    } else {
662      Context->ReportError("anonymous types cannot be exported");
663    }
664    return false;
665  }
666
667  return true;
668}
669
670bool RSExportType::ValidateType(slang::RSContext *Context, clang::ASTContext &C,
671                                clang::QualType QT, clang::NamedDecl *ND,
672                                clang::SourceLocation Loc,
673                                unsigned int TargetAPI, bool IsFilterscript,
674                                bool IsExtern) {
675  const clang::Type *T = QT.getTypePtr();
676  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
677      llvm::SmallPtrSet<const clang::Type*, 8>();
678
679  // If this is an externally visible variable declaration, we check if the
680  // type is able to be exported first.
681  if (auto VD = llvm::dyn_cast_or_null<clang::VarDecl>(ND)) {
682    if (VD->getFormalLinkage() == clang::ExternalLinkage) {
683      if (!TypeExportable(T, Context, VD)) {
684        return false;
685      }
686    }
687  }
688  return ValidateTypeHelper(Context, C, T, ND, Loc, SPS, false, nullptr, TargetAPI,
689                            IsFilterscript, IsExtern);
690}
691
692bool RSExportType::ValidateVarDecl(slang::RSContext *Context,
693                                   clang::VarDecl *VD, unsigned int TargetAPI,
694                                   bool IsFilterscript) {
695  return ValidateType(Context, VD->getASTContext(), VD->getType(), VD,
696                      VD->getLocation(), TargetAPI, IsFilterscript,
697                      (VD->getFormalLinkage() == clang::ExternalLinkage));
698}
699
700const clang::Type
701*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
702  if (DD) {
703    clang::QualType T = DD->getType();
704
705    if (T.isNull())
706      return nullptr;
707    else
708      return T.getTypePtr();
709  }
710  return nullptr;
711}
712
713llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
714  T = GetCanonicalType(T);
715  if (T == nullptr)
716    return llvm::StringRef();
717
718  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
719
720  switch (T->getTypeClass()) {
721    case clang::Type::Builtin: {
722      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
723      BuiltinInfo *info = FindBuiltinType(BT->getKind());
724      if (info != nullptr) {
725        return info->cname[0];
726      }
727      slangAssert(false && "Unknown data type of the builtin");
728      break;
729    }
730    case clang::Type::Record: {
731      clang::RecordDecl *RD;
732      if (T->isStructureType()) {
733        RD = T->getAsStructureType()->getDecl();
734      } else {
735        break;
736      }
737
738      llvm::StringRef Name = RD->getName();
739      if (Name.empty()) {
740        if (RD->getTypedefNameForAnonDecl() != nullptr) {
741          Name = RD->getTypedefNameForAnonDecl()->getName();
742        }
743
744        if (Name.empty()) {
745          // Try to find a name from redeclaration (i.e. typedef)
746          for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
747                   RE = RD->redecls_end();
748               RI != RE;
749               RI++) {
750            slangAssert(*RI != nullptr && "cannot be NULL object");
751
752            Name = (*RI)->getName();
753            if (!Name.empty())
754              break;
755          }
756        }
757      }
758      return Name;
759    }
760    case clang::Type::Pointer: {
761      // "*" plus pointee name
762      const clang::PointerType *P = static_cast<const clang::PointerType*>(CTI);
763      const clang::Type *PT = GetPointeeType(P);
764      llvm::StringRef PointeeName;
765      if (NormalizeType(PT, PointeeName, nullptr, nullptr)) {
766        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
767        Name[0] = '*';
768        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
769        Name[PointeeName.size() + 1] = '\0';
770        return Name;
771      }
772      break;
773    }
774    case clang::Type::ExtVector: {
775      const clang::ExtVectorType *EVT =
776              static_cast<const clang::ExtVectorType*>(CTI);
777      return RSExportVectorType::GetTypeName(EVT);
778      break;
779    }
780    case clang::Type::ConstantArray : {
781      // Construct name for a constant array is too complicated.
782      return "<ConstantArray>";
783    }
784    default: {
785      break;
786    }
787  }
788
789  return llvm::StringRef();
790}
791
792
// Returns the RSExportType describing clang type T, whose normalized name is
// TypeName, creating (and thereby registering) a new one on first request.
//
// Results are memoized by name: if Context already holds an export type
// under TypeName, that object is returned without inspecting T.
// Returns nullptr when T cannot be exported. The outer default case reports
// an error; the class-specific creators may also return nullptr.
//
// NOTE(review): for clang::Type::Pointer, TypeName is expected to be the
// new[]-allocated buffer built by RSExportType::GetTypeName() and is freed
// below. The early cache-hit return does not free it, so a pointer name that
// is looked up a second time appears to leak -- confirm against the callers
// before changing ownership here.
RSExportType *RSExportType::Create(RSContext *Context,
                                   const clang::Type *T,
                                   const llvm::StringRef &TypeName) {
  // Lookup the context to see whether the type was processed before.
  // Newly created RSExportType will insert into context
  // in RSExportType::RSExportType()
  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);

  if (ETI != Context->export_types_end())
    return ETI->second;

  // Canonical form of T. The pointer/vector/array creators below are handed
  // this rather than T itself.
  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();

  RSExportType *ET = nullptr;
  // NOTE(review): dispatch is on T's class while CTI is what gets cast, so
  // this assumes T is already canonical (or at least shares CTI's class).
  switch (T->getTypeClass()) {
    case clang::Type::Record: {
      // Records are told apart by name: RS matrix types, other RS-specific
      // types (handled as primitives), or plain user-defined structs.
      DataType dt = RSExportPrimitiveType::GetRSSpecificType(TypeName);
      switch (dt) {
        case DataTypeUnknown: {
          // User-defined types
          ET = RSExportRecordType::Create(Context,
                                          T->getAsStructureType(),
                                          TypeName);
          break;
        }
        case DataTypeRSMatrix2x2: {
          // 2 x 2 Matrix type
          ET = RSExportMatrixType::Create(Context,
                                          T->getAsStructureType(),
                                          TypeName,
                                          2);
          break;
        }
        case DataTypeRSMatrix3x3: {
          // 3 x 3 Matrix type
          ET = RSExportMatrixType::Create(Context,
                                          T->getAsStructureType(),
                                          TypeName,
                                          3);
          break;
        }
        case DataTypeRSMatrix4x4: {
          // 4 x 4 Matrix type
          ET = RSExportMatrixType::Create(Context,
                                          T->getAsStructureType(),
                                          TypeName,
                                          4);
          break;
        }
        default: {
          // Others are primitive types
          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
          break;
        }
      }
      break;
    }
    case clang::Type::Builtin: {
      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
      break;
    }
    case clang::Type::Pointer: {
      ET = RSExportPointerType::Create(Context,
                                       static_cast<const clang::PointerType*>(CTI),
                                       TypeName);
      // TypeName is not needed past this point: the RSExportType constructor
      // copies it into mName. NOTE(review): the constructor also passes it to
      // Context->insertExportType(); this presumes the export-type map copies
      // its keys -- confirm.
      // FIXME: free the name (allocated in RSExportType::GetTypeName)
      delete [] TypeName.data();
      break;
    }
    case clang::Type::ExtVector: {
      ET = RSExportVectorType::Create(Context,
                                      static_cast<const clang::ExtVectorType*>(CTI),
                                      TypeName);
      break;
    }
    case clang::Type::ConstantArray: {
      // TypeName is not forwarded: for constant arrays GetTypeName() only
      // produces the placeholder "<ConstantArray>".
      ET = RSExportConstantArrayType::Create(
              Context,
              static_cast<const clang::ConstantArrayType*>(CTI));
      break;
    }
    default: {
      Context->ReportError("unknown type cannot be exported: '%0'")
          << T->getTypeClassName();
      break;
    }
  }

  return ET;
}
883
884RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
885  llvm::StringRef TypeName;
886  if (NormalizeType(T, TypeName, Context, nullptr)) {
887    return Create(Context, T, TypeName);
888  } else {
889    return nullptr;
890  }
891}
892
893RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
894                                           const clang::VarDecl *VD) {
895  return RSExportType::Create(Context, GetTypeOfDecl(VD));
896}
897
898size_t RSExportType::getStoreSize() const {
899  return getRSContext()->getDataLayout()->getTypeStoreSize(getLLVMType());
900}
901
902size_t RSExportType::getAllocSize() const {
903    return getRSContext()->getDataLayout()->getTypeAllocSize(getLLVMType());
904}
905
906RSExportType::RSExportType(RSContext *Context,
907                           ExportClass Class,
908                           const llvm::StringRef &Name)
909    : RSExportable(Context, RSExportable::EX_TYPE),
910      mClass(Class),
911      // Make a copy on Name since memory stored @Name is either allocated in
912      // ASTContext or allocated in GetTypeName which will be destroyed later.
913      mName(Name.data(), Name.size()),
914      mLLVMType(nullptr) {
915  // Don't cache the type whose name start with '<'. Those type failed to
916  // get their name since constructing their name in GetTypeName() requiring
917  // complicated work.
918  if (!IsDummyName(Name)) {
919    // TODO(zonr): Need to check whether the insertion is successful or not.
920    Context->insertExportType(llvm::StringRef(Name), this);
921  }
922
923}
924
925bool RSExportType::keep() {
926  if (!RSExportable::keep())
927    return false;
928  // Invalidate converted LLVM type.
929  mLLVMType = nullptr;
930  return true;
931}
932
933bool RSExportType::equals(const RSExportable *E) const {
934  CHECK_PARENT_EQUALITY(RSExportable, E);
935  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
936}
937
938RSExportType::~RSExportType() {
939}
940
941/************************** RSExportPrimitiveType **************************/
// Name -> DataType table for RS-specific matrix and object types. It starts
// empty and is filled on first use by GetRSSpecificType(const StringRef &).
llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
RSExportPrimitiveType::RSSpecificTypeMap;
944
945bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
946  if ((T != nullptr) && (T->getTypeClass() == clang::Type::Builtin))
947    return true;
948  else
949    return false;
950}
951
952DataType
953RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
954  if (TypeName.empty())
955    return DataTypeUnknown;
956
957  if (RSSpecificTypeMap->empty()) {
958    for (int i = 0; i < MatrixAndObjectDataTypesCount; i++) {
959      (*RSSpecificTypeMap)[MatrixAndObjectDataTypes[i].name] =
960          MatrixAndObjectDataTypes[i].dataType;
961    }
962  }
963
964  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
965  if (I == RSSpecificTypeMap->end())
966    return DataTypeUnknown;
967  else
968    return I->getValue();
969}
970
971DataType RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
972  T = GetCanonicalType(T);
973  if ((T == nullptr) || (T->getTypeClass() != clang::Type::Record))
974    return DataTypeUnknown;
975
976  return GetRSSpecificType( RSExportType::GetTypeName(T) );
977}
978
979bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
980    if (DT < 0 || DT >= DataTypeMax) {
981        return false;
982    }
983    return gReflectionTypes[DT].category == MatrixDataType;
984}
985
986bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
987    if (DT < 0 || DT >= DataTypeMax) {
988        return false;
989    }
990    return gReflectionTypes[DT].category == ObjectDataType;
991}
992
993bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
994  bool RSObjectTypeSeen = false;
995  while (T && T->isArrayType()) {
996    T = T->getArrayElementTypeNoTypeQual();
997  }
998
999  const clang::RecordType *RT = T->getAsStructureType();
1000  if (!RT) {
1001    return false;
1002  }
1003
1004  const clang::RecordDecl *RD = RT->getDecl();
1005  if (RD) {
1006    RD = RD->getDefinition();
1007  }
1008  if (!RD) {
1009    return false;
1010  }
1011
1012  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1013         FE = RD->field_end();
1014       FI != FE;
1015       FI++) {
1016    // We just look through all field declarations to see if we find a
1017    // declaration for an RS object type (or an array of one).
1018    const clang::FieldDecl *FD = *FI;
1019    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1020    while (FT && FT->isArrayType()) {
1021      FT = FT->getArrayElementTypeNoTypeQual();
1022    }
1023
1024    DataType DT = GetRSSpecificType(FT);
1025    if (IsRSObjectType(DT)) {
1026      // RS object types definitely need to be zero-initialized
1027      RSObjectTypeSeen = true;
1028    } else {
1029      switch (DT) {
1030        case DataTypeRSMatrix2x2:
1031        case DataTypeRSMatrix3x3:
1032        case DataTypeRSMatrix4x4:
1033          // Matrix types should get zero-initialized as well
1034          RSObjectTypeSeen = true;
1035          break;
1036        default:
1037          // Ignore all other primitive types
1038          break;
1039      }
1040      while (FT && FT->isArrayType()) {
1041        FT = FT->getArrayElementTypeNoTypeQual();
1042      }
1043      if (FT->isStructureType()) {
1044        // Recursively handle structs of structs (even though these can't
1045        // be exported, it is possible for a user to have them internally).
1046        RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
1047      }
1048    }
1049  }
1050
1051  return RSObjectTypeSeen;
1052}
1053
1054size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
1055  int type = EPT->getType();
1056  slangAssert((type > DataTypeUnknown && type < DataTypeMax) &&
1057              "RSExportPrimitiveType::GetSizeInBits : unknown data type");
1058  // All RS object types are 256 bits in 64-bit RS.
1059  if (EPT->isRSObjectType() && EPT->getRSContext()->is64Bit()) {
1060    return 256;
1061  }
1062  return gReflectionTypes[type].size_in_bits;
1063}
1064
1065DataType
1066RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
1067  if (T == nullptr)
1068    return DataTypeUnknown;
1069
1070  switch (T->getTypeClass()) {
1071    case clang::Type::Builtin: {
1072      const clang::BuiltinType *BT =
1073              static_cast<const clang::BuiltinType*>(T->getCanonicalTypeInternal().getTypePtr());
1074      BuiltinInfo *info = FindBuiltinType(BT->getKind());
1075      if (info != nullptr) {
1076        return info->type;
1077      }
1078      // The size of type WChar depend on platform so we abandon the support
1079      // to them.
1080      Context->ReportError("built-in type cannot be exported: '%0'")
1081          << T->getTypeClassName();
1082      break;
1083    }
1084    case clang::Type::Record: {
1085      // must be RS object type
1086      return RSExportPrimitiveType::GetRSSpecificType(T);
1087    }
1088    default: {
1089      Context->ReportError("primitive type cannot be exported: '%0'")
1090          << T->getTypeClassName();
1091      break;
1092    }
1093  }
1094
1095  return DataTypeUnknown;
1096}
1097
1098RSExportPrimitiveType
1099*RSExportPrimitiveType::Create(RSContext *Context,
1100                               const clang::Type *T,
1101                               const llvm::StringRef &TypeName,
1102                               bool Normalized) {
1103  DataType DT = GetDataType(Context, T);
1104
1105  if ((DT == DataTypeUnknown) || TypeName.empty())
1106    return nullptr;
1107  else
1108    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
1109                                     DT, Normalized);
1110}
1111
1112RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
1113                                                     const clang::Type *T) {
1114  llvm::StringRef TypeName;
1115  if (RSExportType::NormalizeType(T, TypeName, Context, nullptr)
1116      && IsPrimitiveType(T)) {
1117    return Create(Context, T, TypeName);
1118  } else {
1119    return nullptr;
1120  }
1121}
1122
1123llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
1124  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1125
1126  if (isRSObjectType()) {
1127    // struct {
1128    //   int *p;
1129    // } __attribute__((packed, aligned(pointer_size)))
1130    //
1131    // which is
1132    //
1133    // <{ [1 x i32] }> in LLVM
1134    //
1135    std::vector<llvm::Type *> Elements;
1136    if (getRSContext()->is64Bit()) {
1137      // 64-bit path
1138      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt64Ty(C), 4));
1139      return llvm::StructType::get(C, Elements, true);
1140    } else {
1141      // 32-bit legacy path
1142      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
1143      return llvm::StructType::get(C, Elements, true);
1144    }
1145  }
1146
1147  switch (mType) {
1148    case DataTypeFloat16: {
1149      return llvm::Type::getHalfTy(C);
1150      break;
1151    }
1152    case DataTypeFloat32: {
1153      return llvm::Type::getFloatTy(C);
1154      break;
1155    }
1156    case DataTypeFloat64: {
1157      return llvm::Type::getDoubleTy(C);
1158      break;
1159    }
1160    case DataTypeBoolean: {
1161      return llvm::Type::getInt1Ty(C);
1162      break;
1163    }
1164    case DataTypeSigned8:
1165    case DataTypeUnsigned8: {
1166      return llvm::Type::getInt8Ty(C);
1167      break;
1168    }
1169    case DataTypeSigned16:
1170    case DataTypeUnsigned16:
1171    case DataTypeUnsigned565:
1172    case DataTypeUnsigned5551:
1173    case DataTypeUnsigned4444: {
1174      return llvm::Type::getInt16Ty(C);
1175      break;
1176    }
1177    case DataTypeSigned32:
1178    case DataTypeUnsigned32: {
1179      return llvm::Type::getInt32Ty(C);
1180      break;
1181    }
1182    case DataTypeSigned64:
1183    case DataTypeUnsigned64: {
1184      return llvm::Type::getInt64Ty(C);
1185      break;
1186    }
1187    default: {
1188      slangAssert(false && "Unknown data type");
1189    }
1190  }
1191
1192  return nullptr;
1193}
1194
1195bool RSExportPrimitiveType::equals(const RSExportable *E) const {
1196  CHECK_PARENT_EQUALITY(RSExportType, E);
1197  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
1198}
1199
1200RSReflectionType *RSExportPrimitiveType::getRSReflectionType(DataType DT) {
1201  if (DT > DataTypeUnknown && DT < DataTypeMax) {
1202    return &gReflectionTypes[DT];
1203  } else {
1204    return nullptr;
1205  }
1206}
1207
1208/**************************** RSExportPointerType ****************************/
1209
1210RSExportPointerType
1211*RSExportPointerType::Create(RSContext *Context,
1212                             const clang::PointerType *PT,
1213                             const llvm::StringRef &TypeName) {
1214  const clang::Type *PointeeType = GetPointeeType(PT);
1215  const RSExportType *PointeeET;
1216
1217  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
1218    PointeeET = RSExportType::Create(Context, PointeeType);
1219  } else {
1220    // Double or higher dimension of pointer, export as int*
1221    PointeeET = RSExportPrimitiveType::Create(Context,
1222                    Context->getASTContext().IntTy.getTypePtr());
1223  }
1224
1225  if (PointeeET == nullptr) {
1226    // Error diagnostic is emitted for corresponding pointee type
1227    return nullptr;
1228  }
1229
1230  return new RSExportPointerType(Context, TypeName, PointeeET);
1231}
1232
1233llvm::Type *RSExportPointerType::convertToLLVMType() const {
1234  llvm::Type *PointeeType = mPointeeType->getLLVMType();
1235  return llvm::PointerType::getUnqual(PointeeType);
1236}
1237
1238bool RSExportPointerType::keep() {
1239  if (!RSExportType::keep())
1240    return false;
1241  const_cast<RSExportType*>(mPointeeType)->keep();
1242  return true;
1243}
1244
1245bool RSExportPointerType::equals(const RSExportable *E) const {
1246  CHECK_PARENT_EQUALITY(RSExportType, E);
1247  return (static_cast<const RSExportPointerType*>(E)
1248              ->getPointeeType()->equals(getPointeeType()));
1249}
1250
1251/***************************** RSExportVectorType *****************************/
1252llvm::StringRef
1253RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
1254  const clang::Type *ElementType = GetExtVectorElementType(EVT);
1255  llvm::StringRef name;
1256
1257  if ((ElementType->getTypeClass() != clang::Type::Builtin))
1258    return name;
1259
1260  const clang::BuiltinType *BT =
1261          static_cast<const clang::BuiltinType*>(
1262              ElementType->getCanonicalTypeInternal().getTypePtr());
1263
1264  if ((EVT->getNumElements() < 1) ||
1265      (EVT->getNumElements() > 4))
1266    return name;
1267
1268  BuiltinInfo *info = FindBuiltinType(BT->getKind());
1269  if (info != nullptr) {
1270    int I = EVT->getNumElements() - 1;
1271    if (I < kMaxVectorSize) {
1272      name = info->cname[I];
1273    } else {
1274      slangAssert(false && "Max vector is 4");
1275    }
1276  }
1277  return name;
1278}
1279
1280RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
1281                                               const clang::ExtVectorType *EVT,
1282                                               const llvm::StringRef &TypeName,
1283                                               bool Normalized) {
1284  slangAssert(EVT != nullptr && EVT->getTypeClass() == clang::Type::ExtVector);
1285
1286  const clang::Type *ElementType = GetExtVectorElementType(EVT);
1287  DataType DT = RSExportPrimitiveType::GetDataType(Context, ElementType);
1288
1289  if (DT != DataTypeUnknown)
1290    return new RSExportVectorType(Context,
1291                                  TypeName,
1292                                  DT,
1293                                  Normalized,
1294                                  EVT->getNumElements());
1295  else
1296    return nullptr;
1297}
1298
1299llvm::Type *RSExportVectorType::convertToLLVMType() const {
1300  llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
1301  return llvm::VectorType::get(ElementType, getNumElement());
1302}
1303
1304bool RSExportVectorType::equals(const RSExportable *E) const {
1305  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
1306  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
1307              == getNumElement());
1308}
1309
1310/***************************** RSExportMatrixType *****************************/
1311RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
1312                                               const clang::RecordType *RT,
1313                                               const llvm::StringRef &TypeName,
1314                                               unsigned Dim) {
1315  slangAssert((RT != nullptr) && (RT->getTypeClass() == clang::Type::Record));
1316  slangAssert((Dim > 1) && "Invalid dimension of matrix");
1317
1318  // Check whether the struct rs_matrix is in our expected form (but assume it's
1319  // correct if we're not sure whether it's correct or not)
1320  const clang::RecordDecl* RD = RT->getDecl();
1321  RD = RD->getDefinition();
1322  if (RD != nullptr) {
1323    // Find definition, perform further examination
1324    if (RD->field_empty()) {
1325      Context->ReportError(
1326          RD->getLocation(),
1327          "invalid matrix struct: must have 1 field for saving values: '%0'")
1328          << RD->getName();
1329      return nullptr;
1330    }
1331
1332    clang::RecordDecl::field_iterator FIT = RD->field_begin();
1333    const clang::FieldDecl *FD = *FIT;
1334    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1335    if ((FT == nullptr) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
1336      Context->ReportError(RD->getLocation(),
1337                           "invalid matrix struct: first field should"
1338                           " be an array with constant size: '%0'")
1339          << RD->getName();
1340      return nullptr;
1341    }
1342    const clang::ConstantArrayType *CAT =
1343      static_cast<const clang::ConstantArrayType *>(FT);
1344    const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1345    if ((ElementType == nullptr) ||
1346        (ElementType->getTypeClass() != clang::Type::Builtin) ||
1347        (static_cast<const clang::BuiltinType *>(ElementType)->getKind() !=
1348         clang::BuiltinType::Float)) {
1349      Context->ReportError(RD->getLocation(),
1350                           "invalid matrix struct: first field "
1351                           "should be a float array: '%0'")
1352          << RD->getName();
1353      return nullptr;
1354    }
1355
1356    if (CAT->getSize() != Dim * Dim) {
1357      Context->ReportError(RD->getLocation(),
1358                           "invalid matrix struct: first field "
1359                           "should be an array with size %0: '%1'")
1360          << (Dim * Dim) << (RD->getName());
1361      return nullptr;
1362    }
1363
1364    FIT++;
1365    if (FIT != RD->field_end()) {
1366      Context->ReportError(RD->getLocation(),
1367                           "invalid matrix struct: must have "
1368                           "exactly 1 field: '%0'")
1369          << RD->getName();
1370      return nullptr;
1371    }
1372  }
1373
1374  return new RSExportMatrixType(Context, TypeName, Dim);
1375}
1376
1377llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1378  // Construct LLVM type:
1379  // struct {
1380  //  float X[mDim * mDim];
1381  // }
1382
1383  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1384  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1385                                            mDim * mDim);
1386  return llvm::StructType::get(C, X, false);
1387}
1388
1389bool RSExportMatrixType::equals(const RSExportable *E) const {
1390  CHECK_PARENT_EQUALITY(RSExportType, E);
1391  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1392}
1393
1394/************************* RSExportConstantArrayType *************************/
1395RSExportConstantArrayType
1396*RSExportConstantArrayType::Create(RSContext *Context,
1397                                   const clang::ConstantArrayType *CAT) {
1398  slangAssert(CAT != nullptr && CAT->getTypeClass() == clang::Type::ConstantArray);
1399
1400  slangAssert((CAT->getSize().getActiveBits() < 32) && "array too large");
1401
1402  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1403  slangAssert((Size > 0) && "Constant array should have size greater than 0");
1404
1405  const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1406  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
1407
1408  if (ElementET == nullptr) {
1409    return nullptr;
1410  }
1411
1412  return new RSExportConstantArrayType(Context,
1413                                       ElementET,
1414                                       Size);
1415}
1416
1417llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1418  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
1419}
1420
1421bool RSExportConstantArrayType::keep() {
1422  if (!RSExportType::keep())
1423    return false;
1424  const_cast<RSExportType*>(mElementType)->keep();
1425  return true;
1426}
1427
1428bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1429  CHECK_PARENT_EQUALITY(RSExportType, E);
1430  const RSExportConstantArrayType *RHS =
1431      static_cast<const RSExportConstantArrayType*>(E);
1432  return ((getSize() == RHS->getSize()) &&
1433          (getElementType()->equals(RHS->getElementType())));
1434}
1435
1436/**************************** RSExportRecordType ****************************/
// Builds the RSExportRecordType for struct type RT named TypeName.
//
// The record object is constructed before its fields are examined. Each
// field then becomes a Field carrying its exported type, its name, and its
// byte offset taken from clang's ASTRecordLayout.
//
// Returns nullptr when any of these holds:
// - the struct has no definition in this module (asserted);
// - it contains a bit-field (no diagnostic is emitted here);
// - a field's type cannot be exported (an error is reported).
//
// NOTE(review): on those failure paths the already-constructed record is
// neither freed nor unregistered. RSExportType's constructor inserts
// non-dummy names into Context, so presumably a later lookup by TypeName
// could return this incomplete record -- confirm how callers handle it.
RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
                                               const clang::RecordType *RT,
                                               const llvm::StringRef &TypeName,
                                               bool mIsArtificial) {
  slangAssert(RT != nullptr && RT->getTypeClass() == clang::Type::Record);

  const clang::RecordDecl *RD = RT->getDecl();
  slangAssert(RD->isStruct());

  RD = RD->getDefinition();
  if (RD == nullptr) {
    slangAssert(false && "struct is not defined in this module");
    return nullptr;
  }

  // Struct layout constructed by clang. We rely on this for obtaining the
  // alloc size of a struct and offset of every field in that struct.
  const clang::ASTRecordLayout *RL =
      &Context->getASTContext().getASTRecordLayout(RD);
  slangAssert((RL != nullptr) &&
      "Failed to retrieve the struct layout from Clang.");

  // Data size (without tail padding) and full size, both in bytes
  // (CharUnits quantities).
  RSExportRecordType *ERT =
      new RSExportRecordType(Context,
                             TypeName,
                             RD->hasAttr<clang::PackedAttr>(),
                             mIsArtificial,
                             RL->getDataSize().getQuantity(),
                             RL->getSize().getQuantity());
  // Position of the current field in declaration order; used to query the
  // layout for its offset.
  unsigned int Index = 0;

  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
           FE = RD->field_end();
       FI != FE;
       FI++, Index++) {

    // FIXME: All fields should be primitive type
    slangAssert(FI->getKind() == clang::Decl::Field);
    clang::FieldDecl *FD = *FI;

    // Bit-fields are not supported; bail out without a diagnostic.
    if (FD->isBitField()) {
      return nullptr;
    }

    // Type
    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);

    if (ET != nullptr) {
      // getFieldOffset() is in bits; >> 3 converts it to a byte offset.
      ERT->mFields.push_back(
          new Field(ET, FD->getName(), ERT,
                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
    } else {
      Context->ReportError(RD->getLocation(),
                           "field type cannot be exported: '%0.%1'")
          << RD->getName() << FD->getName();
      return nullptr;
    }
  }

  return ERT;
}
1498
1499llvm::Type *RSExportRecordType::convertToLLVMType() const {
1500  // Create an opaque type since struct may reference itself recursively.
1501
1502  // TODO(sliao): LLVM took out the OpaqueType. Any other to migrate to?
1503  std::vector<llvm::Type*> FieldTypes;
1504
1505  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1506       FI != FE;
1507       FI++) {
1508    const Field *F = *FI;
1509    const RSExportType *FET = F->getType();
1510
1511    FieldTypes.push_back(FET->getLLVMType());
1512  }
1513
1514  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1515                                               FieldTypes,
1516                                               mIsPacked);
1517  if (ST != nullptr) {
1518    return ST;
1519  } else {
1520    return nullptr;
1521  }
1522}
1523
1524bool RSExportRecordType::keep() {
1525  if (!RSExportType::keep())
1526    return false;
1527  for (std::list<const Field*>::iterator I = mFields.begin(),
1528          E = mFields.end();
1529       I != E;
1530       I++) {
1531    const_cast<RSExportType*>((*I)->getType())->keep();
1532  }
1533  return true;
1534}
1535
1536bool RSExportRecordType::equals(const RSExportable *E) const {
1537  CHECK_PARENT_EQUALITY(RSExportType, E);
1538
1539  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1540
1541  if (ERT->getFields().size() != getFields().size())
1542    return false;
1543
1544  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1545
1546  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1547    if (!(*AI)->getType()->equals((*BI)->getType()))
1548      return false;
1549    AI++;
1550    BI++;
1551  }
1552
1553  return true;
1554}
1555
1556void RSExportType::convertToRTD(RSReflectionTypeData *rtd) const {
1557    memset(rtd, 0, sizeof(*rtd));
1558    rtd->vecSize = 1;
1559
1560    switch(getClass()) {
1561    case RSExportType::ExportClassPrimitive: {
1562            const RSExportPrimitiveType *EPT = static_cast<const RSExportPrimitiveType*>(this);
1563            rtd->type = RSExportPrimitiveType::getRSReflectionType(EPT);
1564            return;
1565        }
1566    case RSExportType::ExportClassPointer: {
1567            const RSExportPointerType *EPT = static_cast<const RSExportPointerType*>(this);
1568            const RSExportType *PointeeType = EPT->getPointeeType();
1569            PointeeType->convertToRTD(rtd);
1570            rtd->isPointer = true;
1571            return;
1572        }
1573    case RSExportType::ExportClassVector: {
1574            const RSExportVectorType *EVT = static_cast<const RSExportVectorType*>(this);
1575            rtd->type = EVT->getRSReflectionType(EVT);
1576            rtd->vecSize = EVT->getNumElement();
1577            return;
1578        }
1579    case RSExportType::ExportClassMatrix: {
1580            const RSExportMatrixType *EMT = static_cast<const RSExportMatrixType*>(this);
1581            unsigned Dim = EMT->getDim();
1582            slangAssert((Dim >= 2) && (Dim <= 4));
1583            rtd->type = &gReflectionTypes[15 + Dim-2];
1584            return;
1585        }
1586    case RSExportType::ExportClassConstantArray: {
1587            const RSExportConstantArrayType* CAT =
1588              static_cast<const RSExportConstantArrayType*>(this);
1589            CAT->getElementType()->convertToRTD(rtd);
1590            rtd->arraySize = CAT->getSize();
1591            return;
1592        }
1593    case RSExportType::ExportClassRecord: {
1594            slangAssert(!"RSExportType::ExportClassRecord not implemented");
1595            return;// RS_TYPE_CLASS_NAME_PREFIX + ET->getName() + ".Item";
1596        }
1597    default: {
1598            slangAssert(false && "Unknown class of type");
1599        }
1600    }
1601}
1602
1603
1604}  // namespace slang
1605