// slang_rs_export_type.cpp revision d5a84f6d49d64738e4bb7c9dea7242e48acad959
/*
 * Copyright 2010-2012, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "slang_rs_export_type.h"
18
19#include <list>
20#include <vector>
21
22#include "clang/AST/RecordLayout.h"
23
24#include "llvm/ADT/StringExtras.h"
25
26#include "llvm/DerivedTypes.h"
27
28#include "llvm/Target/TargetData.h"
29
30#include "llvm/Type.h"
31
32#include "slang_assert.h"
33#include "slang_rs_context.h"
34#include "slang_rs_export_element.h"
35#include "slang_rs_type_spec.h"
36#include "slang_version.h"
37
// Early-return helper for equals() overrides: makes the enclosing function
// return false unless ParentClass::equals(E) holds. Wrapped in
// do { } while (0) so that "CHECK_PARENT_EQUALITY(P, E);" behaves as a single
// statement (safe inside an unbraced if/else).
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  do {                                        \
    if (!ParentClass::equals(E))              \
      return false;                           \
  } while (0)
41
42namespace slang {
43
44namespace {
45
46static RSReflectionType gReflectionTypes[] = {
47    {"FLOAT_16", "F16", 16, "half", "half", "Half", "Half", false},
48    {"FLOAT_32", "F32", 32, "float", "float", "Float", "Float", false},
49    {"FLOAT_64", "F64", 64, "double", "double", "Double", "Double",false},
50    {"SIGNED_8", "I8", 8, "int8_t", "byte", "Byte", "Byte", false},
51    {"SIGNED_16", "I16", 16, "int16_t", "short", "Short", "Short", false},
52    {"SIGNED_32", "I32", 32, "int32_t", "int", "Int", "Int", false},
53    {"SIGNED_64", "I64", 64, "int64_t", "long", "Long", "Long", false},
54    {"UNSIGNED_8", "U8", 8, "uint8_t", "short", "UByte", "Short", true},
55    {"UNSIGNED_16", "U16", 16, "uint16_t", "int", "UShort", "Int", true},
56    {"UNSIGNED_32", "U32", 32, "uint32_t", "long", "UInt", "Long", true},
57    {"UNSIGNED_64", "U64", 64, "uint64_t", "long", "ULong", "Long", false},
58
59    {"BOOLEAN", "BOOLEAN", 8, "bool", "boolean", NULL, NULL, false},
60
61    {"UNSIGNED_5_6_5", NULL, 16, NULL, NULL, NULL, NULL, false},
62    {"UNSIGNED_5_5_5_1", NULL, 16, NULL, NULL, NULL, NULL, false},
63    {"UNSIGNED_4_4_4_4", NULL, 16, NULL, NULL, NULL, NULL, false},
64
65    {"MATRIX_2X2", NULL, 4*32, "rsMatrix_2x2", "Matrix2f", NULL, NULL, false},
66    {"MATRIX_3X3", NULL, 9*32, "rsMatrix_3x3", "Matrix3f", NULL, NULL, false},
67    {"MATRIX_4X4", NULL, 16*32, "rsMatrix_4x4", "Matrix4f", NULL, NULL, false},
68
69    {"RS_ELEMENT", "ELEMENT", 32, "Element", "Element", NULL, NULL, false},
70    {"RS_TYPE", "TYPE", 32, "Type", "Type", NULL, NULL, false},
71    {"RS_ALLOCATION", "ALLOCATION", 32, "Allocation", "Allocation", NULL, NULL, false},
72    {"RS_SAMPLER", "SAMPLER", 32, "Sampler", "Sampler", NULL, NULL, false},
73    {"RS_SCRIPT", "SCRIPT", 32, "Script", "Script", NULL, NULL, false},
74    {"RS_MESH", "MESH", 32, "Mesh", "Mesh", NULL, NULL, false},
75    {"RS_PATH", "PATH", 32, "Path", "Path", NULL, NULL, false},
76    {"RS_PROGRAM_FRAGMENT", "PROGRAM_FRAGMENT", 32, "ProgramFragment", "ProgramFragment", NULL, NULL, false},
77    {"RS_PROGRAM_VERTEX", "PROGRAM_VERTEX", 32, "ProgramVertex", "ProgramVertex", NULL, NULL, false},
78    {"RS_PROGRAM_RASTER", "PROGRAM_RASTER", 32, "ProgramRaster", "ProgramRaster", NULL, NULL, false},
79    {"RS_PROGRAM_STORE", "PROGRAM_STORE", 32, "ProgramStore", "ProgramStore", NULL, NULL, false},
80    {"RS_FONT", "FONT", 32, "Font", "Font", NULL, NULL, false}
81};
82
83static const clang::Type *TypeExportableHelper(
84    const clang::Type *T,
85    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
86    clang::DiagnosticsEngine *DiagEngine,
87    const clang::VarDecl *VD,
88    const clang::RecordDecl *TopLevelRecord);
89
90static void ReportTypeError(clang::DiagnosticsEngine *DiagEngine,
91                            const clang::VarDecl *VD,
92                            const clang::RecordDecl *TopLevelRecord,
93                            const char *Message,
94                            unsigned int TargetAPI = 0) {
95  if (!DiagEngine) {
96    return;
97  }
98
99  const clang::SourceManager &SM = DiagEngine->getSourceManager();
100
101  // Attempt to use the type declaration first (if we have one).
102  // Fall back to the variable definition, if we are looking at something
103  // like an array declaration that can't be exported.
104  if (TopLevelRecord) {
105    DiagEngine->Report(
106      clang::FullSourceLoc(TopLevelRecord->getLocation(), SM),
107      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error, Message))
108      << TopLevelRecord->getName() << TargetAPI;
109  } else if (VD) {
110    DiagEngine->Report(
111      clang::FullSourceLoc(VD->getLocation(), SM),
112      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error, Message))
113      << VD->getName() << TargetAPI;
114  } else {
115    slangAssert(false && "Variables should be validated before exporting");
116  }
117}
118
119static const clang::Type *ConstantArrayTypeExportableHelper(
120    const clang::ConstantArrayType *CAT,
121    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
122    clang::DiagnosticsEngine *DiagEngine,
123    const clang::VarDecl *VD,
124    const clang::RecordDecl *TopLevelRecord) {
125  // Check element type
126  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
127  if (ElementType->isArrayType()) {
128    ReportTypeError(DiagEngine, VD, TopLevelRecord,
129                    "multidimensional arrays cannot be exported: '%0'");
130    return NULL;
131  } else if (ElementType->isExtVectorType()) {
132    const clang::ExtVectorType *EVT =
133        static_cast<const clang::ExtVectorType*>(ElementType);
134    unsigned numElements = EVT->getNumElements();
135
136    const clang::Type *BaseElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
137    if (!RSExportPrimitiveType::IsPrimitiveType(BaseElementType)) {
138      ReportTypeError(DiagEngine, VD, TopLevelRecord,
139        "vectors of non-primitive types cannot be exported: '%0'");
140      return NULL;
141    }
142
143    if (numElements == 3 && CAT->getSize() != 1) {
144      ReportTypeError(DiagEngine, VD, TopLevelRecord,
145        "arrays of width 3 vector types cannot be exported: '%0'");
146      return NULL;
147    }
148  }
149
150  if (TypeExportableHelper(ElementType, SPS, DiagEngine, VD,
151                           TopLevelRecord) == NULL) {
152    return NULL;
153  } else {
154    return CAT;
155  }
156}
157
158static const clang::Type *TypeExportableHelper(
159    clang::Type const *T,
160    llvm::SmallPtrSet<clang::Type const *, 8> &SPS,
161    clang::DiagnosticsEngine *DiagEngine,
162    clang::VarDecl const *VD,
163    clang::RecordDecl const *TopLevelRecord) {
164  // Normalize first
165  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
166    return NULL;
167
168  if (SPS.count(T))
169    return T;
170
171  switch (T->getTypeClass()) {
172    case clang::Type::Builtin: {
173      const clang::BuiltinType *BT =
174        UNSAFE_CAST_TYPE(const clang::BuiltinType, T);
175
176      switch (BT->getKind()) {
177#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
178        case builtin_type:
179#include "RSClangBuiltinEnums.inc"
180          return T;
181        default: {
182          return NULL;
183        }
184      }
185    }
186    case clang::Type::Record: {
187      if (RSExportPrimitiveType::GetRSSpecificType(T) !=
188          RSExportPrimitiveType::DataTypeUnknown) {
189        return T;  // RS object type, no further checks are needed
190      }
191
192      // Check internal struct
193      if (T->isUnionType()) {
194        ReportTypeError(DiagEngine, VD, T->getAsUnionType()->getDecl(),
195                        "unions cannot be exported: '%0'");
196        return NULL;
197      } else if (!T->isStructureType()) {
198        slangAssert(false && "Unknown type cannot be exported");
199        return NULL;
200      }
201
202      clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
203      if (RD != NULL) {
204        RD = RD->getDefinition();
205        if (RD == NULL) {
206          ReportTypeError(DiagEngine, NULL, T->getAsStructureType()->getDecl(),
207                          "struct is not defined in this module");
208          return NULL;
209        }
210      }
211
212      if (!TopLevelRecord) {
213        TopLevelRecord = RD;
214      }
215      if (RD->getName().empty()) {
216        ReportTypeError(DiagEngine, NULL, RD,
217                        "anonymous structures cannot be exported");
218        return NULL;
219      }
220
221      // Fast check
222      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
223        return NULL;
224
225      // Insert myself into checking set
226      SPS.insert(T);
227
228      // Check all element
229      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
230               FE = RD->field_end();
231           FI != FE;
232           FI++) {
233        const clang::FieldDecl *FD = *FI;
234        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
235        FT = GET_CANONICAL_TYPE(FT);
236
237        if (!TypeExportableHelper(FT, SPS, DiagEngine, VD, TopLevelRecord)) {
238          return NULL;
239        }
240
241        // We don't support bit fields yet
242        //
243        // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
244        if (FD->isBitField()) {
245          if (DiagEngine) {
246            DiagEngine->Report(
247              clang::FullSourceLoc(FD->getLocation(),
248                                   DiagEngine->getSourceManager()),
249              DiagEngine->getCustomDiagID(
250                clang::DiagnosticsEngine::Error,
251                "bit fields are not able to be exported: '%0.%1'"))
252              << RD->getName()
253              << FD->getName();
254          }
255          return NULL;
256        }
257      }
258
259      return T;
260    }
261    case clang::Type::Pointer: {
262      if (TopLevelRecord) {
263        ReportTypeError(DiagEngine, NULL, TopLevelRecord,
264          "structures containing pointers cannot be exported: '%0'");
265        return NULL;
266      }
267
268      const clang::PointerType *PT =
269        UNSAFE_CAST_TYPE(const clang::PointerType, T);
270      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
271
272      if (PointeeType->getTypeClass() == clang::Type::Pointer)
273        return T;
274      // We don't support pointer with array-type pointee or unsupported pointee
275      // type
276      if (PointeeType->isArrayType() ||
277          (TypeExportableHelper(PointeeType, SPS, DiagEngine, VD,
278                                TopLevelRecord) == NULL))
279        return NULL;
280      else
281        return T;
282    }
283    case clang::Type::ExtVector: {
284      const clang::ExtVectorType *EVT =
285          UNSAFE_CAST_TYPE(const clang::ExtVectorType, T);
286      // Only vector with size 2, 3 and 4 are supported.
287      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
288        return NULL;
289
290      // Check base element type
291      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
292
293      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
294          (TypeExportableHelper(ElementType, SPS, DiagEngine, VD,
295                                TopLevelRecord) == NULL))
296        return NULL;
297      else
298        return T;
299    }
300    case clang::Type::ConstantArray: {
301      const clang::ConstantArrayType *CAT =
302          UNSAFE_CAST_TYPE(const clang::ConstantArrayType, T);
303
304      return ConstantArrayTypeExportableHelper(CAT, SPS, DiagEngine, VD,
305                                               TopLevelRecord);
306    }
307    default: {
308      return NULL;
309    }
310  }
311}
312
313// Return the type that can be used to create RSExportType, will always return
314// the canonical type
315// If the Type T is not exportable, this function returns NULL. DiagEngine is
316// used to generate proper Clang diagnostic messages when a
317// non-exportable type is detected. TopLevelRecord is used to capture the
318// highest struct (in the case of a nested hierarchy) for detecting other
319// types that cannot be exported (mostly pointers within a struct).
320static const clang::Type *TypeExportable(const clang::Type *T,
321                                         clang::DiagnosticsEngine *DiagEngine,
322                                         const clang::VarDecl *VD) {
323  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
324      llvm::SmallPtrSet<const clang::Type*, 8>();
325
326  return TypeExportableHelper(T, SPS, DiagEngine, VD, NULL);
327}
328
329static bool ValidateRSObjectInVarDecl(clang::VarDecl *VD,
330                                      bool InCompositeType,
331                                      unsigned int TargetAPI) {
332  if (TargetAPI < SLANG_JB_TARGET_API) {
333    // Only if we are already in a composite type (like an array or structure).
334    if (InCompositeType) {
335      // Only if we are actually exported (i.e. non-static).
336      if (VD->getLinkage() == clang::ExternalLinkage) {
337        // Only if we are not a pointer to an object.
338        const clang::Type *T = GET_CANONICAL_TYPE(VD->getType().getTypePtr());
339        if (T->getTypeClass() != clang::Type::Pointer) {
340          clang::ASTContext &C = VD->getASTContext();
341          ReportTypeError(&C.getDiagnostics(), VD, NULL,
342                          "arrays/structures containing RS object types "
343                          "cannot be exported in target API < %1: '%0'",
344                          SLANG_JB_TARGET_API);
345          return false;
346        }
347      }
348    }
349  }
350
351  return true;
352}
353
354// Helper function for ValidateVarDecl(). We do a recursive descent on the
355// type hierarchy to ensure that we can properly export/handle the
356// declaration.
357// \return true if the variable declaration is valid,
358//         false if it is invalid (along with proper diagnostics).
359//
360// VD - top-level variable declaration that we are validating.
361// T - sub-type of VD's type that we are validating.
362// SPS - set of types we have already seen/validated.
363// InCompositeType - true if we are within an outer composite type.
364// UnionDecl - set if we are in a sub-type of a union.
365// TargetAPI - target SDK API level.
366static bool ValidateVarDeclHelper(
367    clang::VarDecl *VD,
368    const clang::Type *&T,
369    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
370    bool InCompositeType,
371    clang::RecordDecl *UnionDecl,
372    unsigned int TargetAPI) {
373  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
374    return true;
375
376  if (SPS.count(T))
377    return true;
378
379  switch (T->getTypeClass()) {
380    case clang::Type::Record: {
381      if (RSExportPrimitiveType::IsRSObjectType(T)) {
382        if (!ValidateRSObjectInVarDecl(VD, InCompositeType, TargetAPI)) {
383          return false;
384        }
385      }
386
387      if (RSExportPrimitiveType::GetRSSpecificType(T) !=
388          RSExportPrimitiveType::DataTypeUnknown) {
389        if (!UnionDecl) {
390          return true;
391        } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
392          clang::ASTContext &C = VD->getASTContext();
393          ReportTypeError(&C.getDiagnostics(), VD, UnionDecl,
394              "unions containing RS object types are not allowed");
395          return false;
396        }
397      }
398
399      clang::RecordDecl *RD = NULL;
400
401      // Check internal struct
402      if (T->isUnionType()) {
403        RD = T->getAsUnionType()->getDecl();
404        UnionDecl = RD;
405      } else if (T->isStructureType()) {
406        RD = T->getAsStructureType()->getDecl();
407      } else {
408        slangAssert(false && "Unknown type cannot be exported");
409        return false;
410      }
411
412      if (RD != NULL) {
413        RD = RD->getDefinition();
414        if (RD == NULL) {
415          // FIXME
416          return true;
417        }
418      }
419
420      // Fast check
421      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
422        return false;
423
424      // Insert myself into checking set
425      SPS.insert(T);
426
427      // Check all elements
428      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
429               FE = RD->field_end();
430           FI != FE;
431           FI++) {
432        const clang::FieldDecl *FD = *FI;
433        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
434        FT = GET_CANONICAL_TYPE(FT);
435
436        if (!ValidateVarDeclHelper(VD, FT, SPS, true, UnionDecl, TargetAPI)) {
437          return false;
438        }
439      }
440
441      return true;
442    }
443
444    case clang::Type::Builtin: {
445      break;
446    }
447
448    case clang::Type::Pointer: {
449      const clang::PointerType *PT =
450        UNSAFE_CAST_TYPE(const clang::PointerType, T);
451      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
452
453      return ValidateVarDeclHelper(VD, PointeeType, SPS, InCompositeType,
454                                   UnionDecl, TargetAPI);
455    }
456
457    case clang::Type::ExtVector: {
458      const clang::ExtVectorType *EVT =
459          UNSAFE_CAST_TYPE(const clang::ExtVectorType, T);
460      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
461      return ValidateVarDeclHelper(VD, ElementType, SPS, true, UnionDecl,
462                                   TargetAPI);
463    }
464
465    case clang::Type::ConstantArray: {
466      const clang::ConstantArrayType *CAT =
467          UNSAFE_CAST_TYPE(const clang::ConstantArrayType, T);
468      const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
469      return ValidateVarDeclHelper(VD, ElementType, SPS, true, UnionDecl,
470                                   TargetAPI);
471    }
472
473    default: {
474      break;
475    }
476  }
477
478  return true;
479}
480
481}  // namespace
482
483/****************************** RSExportType ******************************/
484bool RSExportType::NormalizeType(const clang::Type *&T,
485                                 llvm::StringRef &TypeName,
486                                 clang::DiagnosticsEngine *DiagEngine,
487                                 const clang::VarDecl *VD) {
488  if ((T = TypeExportable(T, DiagEngine, VD)) == NULL) {
489    return false;
490  }
491  // Get type name
492  TypeName = RSExportType::GetTypeName(T);
493  if (TypeName.empty()) {
494    if (DiagEngine) {
495      if (VD) {
496        DiagEngine->Report(
497          clang::FullSourceLoc(VD->getLocation(),
498                               DiagEngine->getSourceManager()),
499          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
500                                      "anonymous types cannot be exported"));
501      } else {
502        DiagEngine->Report(
503          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
504                                      "anonymous types cannot be exported"));
505      }
506    }
507    return false;
508  }
509
510  return true;
511}
512
513bool RSExportType::ValidateVarDecl(clang::VarDecl *VD, unsigned int TargetAPI) {
514  const clang::Type *T = VD->getType().getTypePtr();
515  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
516      llvm::SmallPtrSet<const clang::Type*, 8>();
517
518  return ValidateVarDeclHelper(VD, T, SPS, false, NULL, TargetAPI);
519}
520
521const clang::Type
522*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
523  if (DD) {
524    clang::QualType T;
525    if (DD->getTypeSourceInfo())
526      T = DD->getTypeSourceInfo()->getType();
527    else
528      T = DD->getType();
529
530    if (T.isNull())
531      return NULL;
532    else
533      return T.getTypePtr();
534  }
535  return NULL;
536}
537
538llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
539  T = GET_CANONICAL_TYPE(T);
540  if (T == NULL)
541    return llvm::StringRef();
542
543  switch (T->getTypeClass()) {
544    case clang::Type::Builtin: {
545      const clang::BuiltinType *BT =
546        UNSAFE_CAST_TYPE(const clang::BuiltinType, T);
547
548      switch (BT->getKind()) {
549#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
550        case builtin_type:                                    \
551          return cname;                                       \
552        break;
553#include "RSClangBuiltinEnums.inc"
554        default: {
555          slangAssert(false && "Unknown data type of the builtin");
556          break;
557        }
558      }
559      break;
560    }
561    case clang::Type::Record: {
562      clang::RecordDecl *RD;
563      if (T->isStructureType()) {
564        RD = T->getAsStructureType()->getDecl();
565      } else {
566        break;
567      }
568
569      llvm::StringRef Name = RD->getName();
570      if (Name.empty()) {
571        if (RD->getTypedefNameForAnonDecl() != NULL) {
572          Name = RD->getTypedefNameForAnonDecl()->getName();
573        }
574
575        if (Name.empty()) {
576          // Try to find a name from redeclaration (i.e. typedef)
577          for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
578                   RE = RD->redecls_end();
579               RI != RE;
580               RI++) {
581            slangAssert(*RI != NULL && "cannot be NULL object");
582
583            Name = (*RI)->getName();
584            if (!Name.empty())
585              break;
586          }
587        }
588      }
589      return Name;
590    }
591    case clang::Type::Pointer: {
592      // "*" plus pointee name
593      const clang::Type *PT = GET_POINTEE_TYPE(T);
594      llvm::StringRef PointeeName;
595      if (NormalizeType(PT, PointeeName, NULL, NULL)) {
596        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
597        Name[0] = '*';
598        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
599        Name[PointeeName.size() + 1] = '\0';
600        return Name;
601      }
602      break;
603    }
604    case clang::Type::ExtVector: {
605      const clang::ExtVectorType *EVT =
606          UNSAFE_CAST_TYPE(const clang::ExtVectorType, T);
607      return RSExportVectorType::GetTypeName(EVT);
608      break;
609    }
610    case clang::Type::ConstantArray : {
611      // Construct name for a constant array is too complicated.
612      return DUMMY_TYPE_NAME_FOR_RS_CONSTANT_ARRAY_TYPE;
613    }
614    default: {
615      break;
616    }
617  }
618
619  return llvm::StringRef();
620}
621
622
623RSExportType *RSExportType::Create(RSContext *Context,
624                                   const clang::Type *T,
625                                   const llvm::StringRef &TypeName) {
626  // Lookup the context to see whether the type was processed before.
627  // Newly created RSExportType will insert into context
628  // in RSExportType::RSExportType()
629  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
630
631  if (ETI != Context->export_types_end())
632    return ETI->second;
633
634  RSExportType *ET = NULL;
635  switch (T->getTypeClass()) {
636    case clang::Type::Record: {
637      RSExportPrimitiveType::DataType dt =
638          RSExportPrimitiveType::GetRSSpecificType(TypeName);
639      switch (dt) {
640        case RSExportPrimitiveType::DataTypeUnknown: {
641          // User-defined types
642          ET = RSExportRecordType::Create(Context,
643                                          T->getAsStructureType(),
644                                          TypeName);
645          break;
646        }
647        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
648          // 2 x 2 Matrix type
649          ET = RSExportMatrixType::Create(Context,
650                                          T->getAsStructureType(),
651                                          TypeName,
652                                          2);
653          break;
654        }
655        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
656          // 3 x 3 Matrix type
657          ET = RSExportMatrixType::Create(Context,
658                                          T->getAsStructureType(),
659                                          TypeName,
660                                          3);
661          break;
662        }
663        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
664          // 4 x 4 Matrix type
665          ET = RSExportMatrixType::Create(Context,
666                                          T->getAsStructureType(),
667                                          TypeName,
668                                          4);
669          break;
670        }
671        default: {
672          // Others are primitive types
673          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
674          break;
675        }
676      }
677      break;
678    }
679    case clang::Type::Builtin: {
680      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
681      break;
682    }
683    case clang::Type::Pointer: {
684      ET = RSExportPointerType::Create(Context,
685               UNSAFE_CAST_TYPE(const clang::PointerType, T), TypeName);
686      // FIXME: free the name (allocated in RSExportType::GetTypeName)
687      delete [] TypeName.data();
688      break;
689    }
690    case clang::Type::ExtVector: {
691      ET = RSExportVectorType::Create(Context,
692               UNSAFE_CAST_TYPE(const clang::ExtVectorType, T), TypeName);
693      break;
694    }
695    case clang::Type::ConstantArray: {
696      ET = RSExportConstantArrayType::Create(
697              Context,
698              UNSAFE_CAST_TYPE(const clang::ConstantArrayType, T));
699      break;
700    }
701    default: {
702      clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
703      DiagEngine->Report(
704        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
705                                    "unknown type cannot be exported: '%0'"))
706        << T->getTypeClassName();
707      break;
708    }
709  }
710
711  return ET;
712}
713
714RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
715  llvm::StringRef TypeName;
716  if (NormalizeType(T, TypeName, Context->getDiagnostics(), NULL)) {
717    return Create(Context, T, TypeName);
718  } else {
719    return NULL;
720  }
721}
722
723RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
724                                           const clang::VarDecl *VD) {
725  return RSExportType::Create(Context, GetTypeOfDecl(VD));
726}
727
728size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
729  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
730      ET->getLLVMType());
731}
732
733size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
734  if (ET->getClass() == RSExportType::ExportClassRecord)
735    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
736  else
737    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
738        ET->getLLVMType());
739}
740
741RSExportType::RSExportType(RSContext *Context,
742                           ExportClass Class,
743                           const llvm::StringRef &Name)
744    : RSExportable(Context, RSExportable::EX_TYPE),
745      mClass(Class),
746      // Make a copy on Name since memory stored @Name is either allocated in
747      // ASTContext or allocated in GetTypeName which will be destroyed later.
748      mName(Name.data(), Name.size()),
749      mLLVMType(NULL),
750      mSpecType(NULL) {
751  // Don't cache the type whose name start with '<'. Those type failed to
752  // get their name since constructing their name in GetTypeName() requiring
753  // complicated work.
754  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
755    // TODO(zonr): Need to check whether the insertion is successful or not.
756    Context->insertExportType(llvm::StringRef(Name), this);
757  return;
758}
759
760bool RSExportType::keep() {
761  if (!RSExportable::keep())
762    return false;
763  // Invalidate converted LLVM type.
764  mLLVMType = NULL;
765  return true;
766}
767
768bool RSExportType::equals(const RSExportable *E) const {
769  CHECK_PARENT_EQUALITY(RSExportable, E);
770  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
771}
772
773RSExportType::~RSExportType() {
774  delete mSpecType;
775}
776
777/************************** RSExportPrimitiveType **************************/
778llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
779RSExportPrimitiveType::RSSpecificTypeMap;
780
781llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
782
783bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
784  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
785    return true;
786  else
787    return false;
788}
789
790RSExportPrimitiveType::DataType
791RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
792  if (TypeName.empty())
793    return DataTypeUnknown;
794
795  if (RSSpecificTypeMap->empty()) {
796#define ENUM_RS_MATRIX_TYPE(type, cname, dim)                       \
797    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
798#include "RSMatrixTypeEnums.inc"
799#define ENUM_RS_OBJECT_TYPE(type, cname)                            \
800    RSSpecificTypeMap->GetOrCreateValue(cname, DataType ## type);
801#include "RSObjectTypeEnums.inc"
802  }
803
804  RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
805  if (I == RSSpecificTypeMap->end())
806    return DataTypeUnknown;
807  else
808    return I->getValue();
809}
810
811RSExportPrimitiveType::DataType
812RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
813  T = GET_CANONICAL_TYPE(T);
814  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
815    return DataTypeUnknown;
816
817  return GetRSSpecificType( RSExportType::GetTypeName(T) );
818}
819
820bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
821  return ((DT >= FirstRSMatrixType) && (DT <= LastRSMatrixType));
822}
823
824bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
825  return ((DT >= FirstRSObjectType) && (DT <= LastRSObjectType));
826}
827
828bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
829  bool RSObjectTypeSeen = false;
830  while (T && T->isArrayType()) {
831    T = T->getArrayElementTypeNoTypeQual();
832  }
833
834  const clang::RecordType *RT = T->getAsStructureType();
835  if (!RT) {
836    return false;
837  }
838  const clang::RecordDecl *RD = RT->getDecl();
839  RD = RD->getDefinition();
840  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
841         FE = RD->field_end();
842       FI != FE;
843       FI++) {
844    // We just look through all field declarations to see if we find a
845    // declaration for an RS object type (or an array of one).
846    const clang::FieldDecl *FD = *FI;
847    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
848    while (FT && FT->isArrayType()) {
849      FT = FT->getArrayElementTypeNoTypeQual();
850    }
851
852    RSExportPrimitiveType::DataType DT = GetRSSpecificType(FT);
853    if (IsRSObjectType(DT)) {
854      // RS object types definitely need to be zero-initialized
855      RSObjectTypeSeen = true;
856    } else {
857      switch (DT) {
858        case RSExportPrimitiveType::DataTypeRSMatrix2x2:
859        case RSExportPrimitiveType::DataTypeRSMatrix3x3:
860        case RSExportPrimitiveType::DataTypeRSMatrix4x4:
861          // Matrix types should get zero-initialized as well
862          RSObjectTypeSeen = true;
863          break;
864        default:
865          // Ignore all other primitive types
866          break;
867      }
868      while (FT && FT->isArrayType()) {
869        FT = FT->getArrayElementTypeNoTypeQual();
870      }
871      if (FT->isStructureType()) {
872        // Recursively handle structs of structs (even though these can't
873        // be exported, it is possible for a user to have them internally).
874        RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
875      }
876    }
877  }
878
879  return RSObjectTypeSeen;
880}
881
// Bit width of each DataType, indexed by the enum value (see GetSizeInBits).
// One entry is generated per ENUM_RS_DATA_TYPE record in RSDataTypeEnums.inc,
// followed by a trailing 0 for the DataTypeMax slot.
const size_t RSExportPrimitiveType::SizeOfDataTypeInBits[] = {
#define ENUM_RS_DATA_TYPE(type, cname, bits)  \
  bits,
#include "RSDataTypeEnums.inc"
  0   // DataTypeMax
};
888
889size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
890  slangAssert(((EPT->getType() > DataTypeUnknown) &&
891               (EPT->getType() < DataTypeMax)) &&
892              "RSExportPrimitiveType::GetSizeInBits : unknown data type");
893  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
894}
895
896RSExportPrimitiveType::DataType
897RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
898  if (T == NULL)
899    return DataTypeUnknown;
900
901  switch (T->getTypeClass()) {
902    case clang::Type::Builtin: {
903      const clang::BuiltinType *BT =
904        UNSAFE_CAST_TYPE(const clang::BuiltinType, T);
905      switch (BT->getKind()) {
906#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
907        case builtin_type: {                                  \
908          return DataType ## type;                            \
909        }
910#include "RSClangBuiltinEnums.inc"
911        // The size of type WChar depend on platform so we abandon the support
912        // to them.
913        default: {
914          clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
915          DiagEngine->Report(
916            DiagEngine->getCustomDiagID(
917              clang::DiagnosticsEngine::Error,
918              "built-in type cannot be exported: '%0'"))
919            << T->getTypeClassName();
920          break;
921        }
922      }
923      break;
924    }
925    case clang::Type::Record: {
926      // must be RS object type
927      return RSExportPrimitiveType::GetRSSpecificType(T);
928    }
929    default: {
930      clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
931      DiagEngine->Report(
932        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
933                                    "primitive type cannot be exported: '%0'"))
934          << T->getTypeClassName();
935      break;
936    }
937  }
938
939  return DataTypeUnknown;
940}
941
942RSExportPrimitiveType
943*RSExportPrimitiveType::Create(RSContext *Context,
944                               const clang::Type *T,
945                               const llvm::StringRef &TypeName,
946                               bool Normalized) {
947  DataType DT = GetDataType(Context, T);
948
949  if ((DT == DataTypeUnknown) || TypeName.empty())
950    return NULL;
951  else
952    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
953                                     DT, Normalized);
954}
955
956RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
957                                                     const clang::Type *T) {
958  llvm::StringRef TypeName;
959  if (RSExportType::NormalizeType(T, TypeName, Context->getDiagnostics(), NULL)
960      && IsPrimitiveType(T)) {
961    return Create(Context, T, TypeName);
962  } else {
963    return NULL;
964  }
965}
966
967llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
968  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
969
970  if (isRSObjectType()) {
971    // struct {
972    //   int *p;
973    // } __attribute__((packed, aligned(pointer_size)))
974    //
975    // which is
976    //
977    // <{ [1 x i32] }> in LLVM
978    //
979    if (RSObjectLLVMType == NULL) {
980      std::vector<llvm::Type *> Elements;
981      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
982      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
983    }
984    return RSObjectLLVMType;
985  }
986
987  switch (mType) {
988    case DataTypeFloat32: {
989      return llvm::Type::getFloatTy(C);
990      break;
991    }
992    case DataTypeFloat64: {
993      return llvm::Type::getDoubleTy(C);
994      break;
995    }
996    case DataTypeBoolean: {
997      return llvm::Type::getInt1Ty(C);
998      break;
999    }
1000    case DataTypeSigned8:
1001    case DataTypeUnsigned8: {
1002      return llvm::Type::getInt8Ty(C);
1003      break;
1004    }
1005    case DataTypeSigned16:
1006    case DataTypeUnsigned16:
1007    case DataTypeUnsigned565:
1008    case DataTypeUnsigned5551:
1009    case DataTypeUnsigned4444: {
1010      return llvm::Type::getInt16Ty(C);
1011      break;
1012    }
1013    case DataTypeSigned32:
1014    case DataTypeUnsigned32: {
1015      return llvm::Type::getInt32Ty(C);
1016      break;
1017    }
1018    case DataTypeSigned64:
1019    case DataTypeUnsigned64: {
1020      return llvm::Type::getInt64Ty(C);
1021      break;
1022    }
1023    default: {
1024      slangAssert(false && "Unknown data type");
1025    }
1026  }
1027
1028  return NULL;
1029}
1030
1031union RSType *RSExportPrimitiveType::convertToSpecType() const {
1032  llvm::OwningPtr<union RSType> ST(new union RSType);
1033  RS_TYPE_SET_CLASS(ST, RS_TC_Primitive);
1034  // enum RSExportPrimitiveType::DataType is synced with enum RSDataType in
1035  // slang_rs_type_spec.h
1036  RS_PRIMITIVE_TYPE_SET_DATA_TYPE(ST, getType());
1037  return ST.take();
1038}
1039
1040bool RSExportPrimitiveType::equals(const RSExportable *E) const {
1041  CHECK_PARENT_EQUALITY(RSExportType, E);
1042  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
1043}
1044
1045RSReflectionType *RSExportPrimitiveType::getRSReflectionType(DataType DT) {
1046  if (DT > DataTypeUnknown && DT < DataTypeMax) {
1047    return &gReflectionTypes[DT];
1048  } else {
1049    return NULL;
1050  }
1051}
1052
1053/**************************** RSExportPointerType ****************************/
1054
1055RSExportPointerType
1056*RSExportPointerType::Create(RSContext *Context,
1057                             const clang::PointerType *PT,
1058                             const llvm::StringRef &TypeName) {
1059  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
1060  const RSExportType *PointeeET;
1061
1062  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
1063    PointeeET = RSExportType::Create(Context, PointeeType);
1064  } else {
1065    // Double or higher dimension of pointer, export as int*
1066    PointeeET = RSExportPrimitiveType::Create(Context,
1067                    Context->getASTContext().IntTy.getTypePtr());
1068  }
1069
1070  if (PointeeET == NULL) {
1071    // Error diagnostic is emitted for corresponding pointee type
1072    return NULL;
1073  }
1074
1075  return new RSExportPointerType(Context, TypeName, PointeeET);
1076}
1077
1078llvm::Type *RSExportPointerType::convertToLLVMType() const {
1079  llvm::Type *PointeeType = mPointeeType->getLLVMType();
1080  return llvm::PointerType::getUnqual(PointeeType);
1081}
1082
1083union RSType *RSExportPointerType::convertToSpecType() const {
1084  llvm::OwningPtr<union RSType> ST(new union RSType);
1085
1086  RS_TYPE_SET_CLASS(ST, RS_TC_Pointer);
1087  RS_POINTER_TYPE_SET_POINTEE_TYPE(ST, getPointeeType()->getSpecType());
1088
1089  if (RS_POINTER_TYPE_GET_POINTEE_TYPE(ST) != NULL)
1090    return ST.take();
1091  else
1092    return NULL;
1093}
1094
1095bool RSExportPointerType::keep() {
1096  if (!RSExportType::keep())
1097    return false;
1098  const_cast<RSExportType*>(mPointeeType)->keep();
1099  return true;
1100}
1101
1102bool RSExportPointerType::equals(const RSExportable *E) const {
1103  CHECK_PARENT_EQUALITY(RSExportType, E);
1104  return (static_cast<const RSExportPointerType*>(E)
1105              ->getPointeeType()->equals(getPointeeType()));
1106}
1107
1108/***************************** RSExportVectorType *****************************/
1109llvm::StringRef
1110RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
1111  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
1112
1113  if ((ElementType->getTypeClass() != clang::Type::Builtin))
1114    return llvm::StringRef();
1115
1116  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(const clang::BuiltinType,
1117                                                  ElementType);
1118  if ((EVT->getNumElements() < 1) ||
1119      (EVT->getNumElements() > 4))
1120    return llvm::StringRef();
1121
1122  switch (BT->getKind()) {
1123    // Compiler is smart enough to optimize following *big if branches* since
1124    // they all become "constant comparison" after macro expansion
1125#define ENUM_SUPPORT_BUILTIN_TYPE(builtin_type, type, cname)  \
1126    case builtin_type: {                                      \
1127      const char *Name[] = { cname"2", cname"3", cname"4" };  \
1128      return Name[EVT->getNumElements() - 2];                 \
1129      break;                                                  \
1130    }
1131#include "RSClangBuiltinEnums.inc"
1132    default: {
1133      return llvm::StringRef();
1134    }
1135  }
1136}
1137
1138RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
1139                                               const clang::ExtVectorType *EVT,
1140                                               const llvm::StringRef &TypeName,
1141                                               bool Normalized) {
1142  slangAssert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
1143
1144  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
1145  RSExportPrimitiveType::DataType DT =
1146      RSExportPrimitiveType::GetDataType(Context, ElementType);
1147
1148  if (DT != RSExportPrimitiveType::DataTypeUnknown)
1149    return new RSExportVectorType(Context,
1150                                  TypeName,
1151                                  DT,
1152                                  Normalized,
1153                                  EVT->getNumElements());
1154  else
1155    return NULL;
1156}
1157
1158llvm::Type *RSExportVectorType::convertToLLVMType() const {
1159  llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
1160  return llvm::VectorType::get(ElementType, getNumElement());
1161}
1162
1163union RSType *RSExportVectorType::convertToSpecType() const {
1164  llvm::OwningPtr<union RSType> ST(new union RSType);
1165
1166  RS_TYPE_SET_CLASS(ST, RS_TC_Vector);
1167  RS_VECTOR_TYPE_SET_ELEMENT_TYPE(ST, getType());
1168  RS_VECTOR_TYPE_SET_VECTOR_SIZE(ST, getNumElement());
1169
1170  return ST.take();
1171}
1172
1173bool RSExportVectorType::equals(const RSExportable *E) const {
1174  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
1175  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
1176              == getNumElement());
1177}
1178
1179/***************************** RSExportMatrixType *****************************/
1180RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
1181                                               const clang::RecordType *RT,
1182                                               const llvm::StringRef &TypeName,
1183                                               unsigned Dim) {
1184  slangAssert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
1185  slangAssert((Dim > 1) && "Invalid dimension of matrix");
1186
1187  // Check whether the struct rs_matrix is in our expected form (but assume it's
1188  // correct if we're not sure whether it's correct or not)
1189  const clang::RecordDecl* RD = RT->getDecl();
1190  RD = RD->getDefinition();
1191  if (RD != NULL) {
1192    clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
1193    const clang::SourceManager *SM = Context->getSourceManager();
1194    // Find definition, perform further examination
1195    if (RD->field_empty()) {
1196      DiagEngine->Report(
1197        clang::FullSourceLoc(RD->getLocation(), *SM),
1198        DiagEngine->getCustomDiagID(
1199          clang::DiagnosticsEngine::Error,
1200          "invalid matrix struct: must have 1 field for saving values: '%0'"))
1201           << RD->getName();
1202      return NULL;
1203    }
1204
1205    clang::RecordDecl::field_iterator FIT = RD->field_begin();
1206    const clang::FieldDecl *FD = *FIT;
1207    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1208    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
1209      DiagEngine->Report(
1210        clang::FullSourceLoc(RD->getLocation(), *SM),
1211        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
1212                                    "invalid matrix struct: first field should"
1213                                    " be an array with constant size: '%0'"))
1214        << RD->getName();
1215      return NULL;
1216    }
1217    const clang::ConstantArrayType *CAT =
1218      static_cast<const clang::ConstantArrayType *>(FT);
1219    const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
1220    if ((ElementType == NULL) ||
1221        (ElementType->getTypeClass() != clang::Type::Builtin) ||
1222        (static_cast<const clang::BuiltinType *>(ElementType)->getKind() !=
1223         clang::BuiltinType::Float)) {
1224      DiagEngine->Report(
1225        clang::FullSourceLoc(RD->getLocation(), *SM),
1226        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
1227                                    "invalid matrix struct: first field "
1228                                    "should be a float array: '%0'"))
1229        << RD->getName();
1230      return NULL;
1231    }
1232
1233    if (CAT->getSize() != Dim * Dim) {
1234      DiagEngine->Report(
1235        clang::FullSourceLoc(RD->getLocation(), *SM),
1236        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
1237                                    "invalid matrix struct: first field "
1238                                    "should be an array with size %0: '%1'"))
1239        << (Dim * Dim) << (RD->getName());
1240      return NULL;
1241    }
1242
1243    FIT++;
1244    if (FIT != RD->field_end()) {
1245      DiagEngine->Report(
1246        clang::FullSourceLoc(RD->getLocation(), *SM),
1247        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
1248                                    "invalid matrix struct: must have "
1249                                    "exactly 1 field: '%0'"))
1250        << RD->getName();
1251      return NULL;
1252    }
1253  }
1254
1255  return new RSExportMatrixType(Context, TypeName, Dim);
1256}
1257
1258llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1259  // Construct LLVM type:
1260  // struct {
1261  //  float X[mDim * mDim];
1262  // }
1263
1264  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1265  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1266                                            mDim * mDim);
1267  return llvm::StructType::get(C, X, false);
1268}
1269
1270union RSType *RSExportMatrixType::convertToSpecType() const {
1271  llvm::OwningPtr<union RSType> ST(new union RSType);
1272  RS_TYPE_SET_CLASS(ST, RS_TC_Matrix);
1273  switch (getDim()) {
1274    case 2: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix2x2); break;
1275    case 3: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix3x3); break;
1276    case 4: RS_MATRIX_TYPE_SET_DATA_TYPE(ST, RS_DT_RSMatrix4x4); break;
1277    default: slangAssert(false && "Matrix type with unsupported dimension.");
1278  }
1279  return ST.take();
1280}
1281
1282bool RSExportMatrixType::equals(const RSExportable *E) const {
1283  CHECK_PARENT_EQUALITY(RSExportType, E);
1284  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1285}
1286
1287/************************* RSExportConstantArrayType *************************/
1288RSExportConstantArrayType
1289*RSExportConstantArrayType::Create(RSContext *Context,
1290                                   const clang::ConstantArrayType *CAT) {
1291  slangAssert(CAT != NULL && CAT->getTypeClass() == clang::Type::ConstantArray);
1292
1293  slangAssert((CAT->getSize().getActiveBits() < 32) && "array too large");
1294
1295  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1296  slangAssert((Size > 0) && "Constant array should have size greater than 0");
1297
1298  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
1299  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
1300
1301  if (ElementET == NULL) {
1302    return NULL;
1303  }
1304
1305  return new RSExportConstantArrayType(Context,
1306                                       ElementET,
1307                                       Size);
1308}
1309
1310llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1311  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
1312}
1313
1314union RSType *RSExportConstantArrayType::convertToSpecType() const {
1315  llvm::OwningPtr<union RSType> ST(new union RSType);
1316
1317  RS_TYPE_SET_CLASS(ST, RS_TC_ConstantArray);
1318  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_TYPE(
1319      ST, getElementType()->getSpecType());
1320  RS_CONSTANT_ARRAY_TYPE_SET_ELEMENT_SIZE(ST, getSize());
1321
1322  if (RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(ST) != NULL)
1323    return ST.take();
1324  else
1325    return NULL;
1326}
1327
1328bool RSExportConstantArrayType::keep() {
1329  if (!RSExportType::keep())
1330    return false;
1331  const_cast<RSExportType*>(mElementType)->keep();
1332  return true;
1333}
1334
1335bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1336  CHECK_PARENT_EQUALITY(RSExportType, E);
1337  const RSExportConstantArrayType *RHS =
1338      static_cast<const RSExportConstantArrayType*>(E);
1339  return ((getSize() == RHS->getSize()) &&
1340          (getElementType()->equals(RHS->getElementType())));
1341}
1342
1343/**************************** RSExportRecordType ****************************/
1344RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1345                                               const clang::RecordType *RT,
1346                                               const llvm::StringRef &TypeName,
1347                                               bool mIsArtificial) {
1348  slangAssert(RT != NULL && RT->getTypeClass() == clang::Type::Record);
1349
1350  const clang::RecordDecl *RD = RT->getDecl();
1351  slangAssert(RD->isStruct());
1352
1353  RD = RD->getDefinition();
1354  if (RD == NULL) {
1355    slangAssert(false && "struct is not defined in this module");
1356    return NULL;
1357  }
1358
1359  // Struct layout construct by clang. We rely on this for obtaining the
1360  // alloc size of a struct and offset of every field in that struct.
1361  const clang::ASTRecordLayout *RL =
1362      &Context->getASTContext().getASTRecordLayout(RD);
1363  slangAssert((RL != NULL) &&
1364      "Failed to retrieve the struct layout from Clang.");
1365
1366  RSExportRecordType *ERT =
1367      new RSExportRecordType(Context,
1368                             TypeName,
1369                             RD->hasAttr<clang::PackedAttr>(),
1370                             mIsArtificial,
1371                             RL->getSize().getQuantity());
1372  unsigned int Index = 0;
1373
1374  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1375           FE = RD->field_end();
1376       FI != FE;
1377       FI++, Index++) {
1378    clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
1379
1380    // FIXME: All fields should be primitive type
1381    slangAssert((*FI)->getKind() == clang::Decl::Field);
1382    clang::FieldDecl *FD = *FI;
1383
1384    if (FD->isBitField()) {
1385      return NULL;
1386    }
1387
1388    // Type
1389    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1390
1391    if (ET != NULL) {
1392      ERT->mFields.push_back(
1393          new Field(ET, FD->getName(), ERT,
1394                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1395    } else {
1396      DiagEngine->Report(
1397        clang::FullSourceLoc(RD->getLocation(), DiagEngine->getSourceManager()),
1398        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
1399                                    "field type cannot be exported: '%0.%1'"))
1400        << RD->getName() << FD->getName();
1401      return NULL;
1402    }
1403  }
1404
1405  return ERT;
1406}
1407
1408llvm::Type *RSExportRecordType::convertToLLVMType() const {
1409  // Create an opaque type since struct may reference itself recursively.
1410
1411  // TODO(sliao): LLVM took out the OpaqueType. Any other to migrate to?
1412  std::vector<llvm::Type*> FieldTypes;
1413
1414  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1415       FI != FE;
1416       FI++) {
1417    const Field *F = *FI;
1418    const RSExportType *FET = F->getType();
1419
1420    FieldTypes.push_back(FET->getLLVMType());
1421  }
1422
1423  llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1424                                               FieldTypes,
1425                                               mIsPacked);
1426  if (ST != NULL) {
1427    return ST;
1428  } else {
1429    return NULL;
1430  }
1431}
1432
1433union RSType *RSExportRecordType::convertToSpecType() const {
1434  unsigned NumFields = getFields().size();
1435  unsigned AllocSize = sizeof(union RSType) +
1436                       sizeof(struct RSRecordField) * NumFields;
1437  llvm::OwningPtr<union RSType> ST(
1438      reinterpret_cast<union RSType*>(operator new(AllocSize)));
1439
1440  ::memset(ST.get(), 0, AllocSize);
1441
1442  RS_TYPE_SET_CLASS(ST, RS_TC_Record);
1443  RS_RECORD_TYPE_SET_NAME(ST, getName().c_str());
1444  RS_RECORD_TYPE_SET_NUM_FIELDS(ST, NumFields);
1445
1446  setSpecTypeTemporarily(ST.get());
1447
1448  unsigned FieldIdx = 0;
1449  for (const_field_iterator FI = fields_begin(), FE = fields_end();
1450       FI != FE;
1451       FI++, FieldIdx++) {
1452    const Field *F = *FI;
1453
1454    RS_RECORD_TYPE_SET_FIELD_NAME(ST, FieldIdx, F->getName().c_str());
1455    RS_RECORD_TYPE_SET_FIELD_TYPE(ST, FieldIdx, F->getType()->getSpecType());
1456  }
1457
1458  // TODO(slang): Check whether all fields were created normally.
1459
1460  return ST.take();
1461}
1462
1463bool RSExportRecordType::keep() {
1464  if (!RSExportType::keep())
1465    return false;
1466  for (std::list<const Field*>::iterator I = mFields.begin(),
1467          E = mFields.end();
1468       I != E;
1469       I++) {
1470    const_cast<RSExportType*>((*I)->getType())->keep();
1471  }
1472  return true;
1473}
1474
1475bool RSExportRecordType::equals(const RSExportable *E) const {
1476  CHECK_PARENT_EQUALITY(RSExportType, E);
1477
1478  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1479
1480  if (ERT->getFields().size() != getFields().size())
1481    return false;
1482
1483  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1484
1485  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1486    if (!(*AI)->getType()->equals((*BI)->getType()))
1487      return false;
1488    AI++;
1489    BI++;
1490  }
1491
1492  return true;
1493}
1494
1495void RSExportType::convertToRTD(RSReflectionTypeData *rtd) const {
1496    memset(rtd, 0, sizeof(*rtd));
1497    rtd->vecSize = 1;
1498
1499    switch(getClass()) {
1500    case RSExportType::ExportClassPrimitive: {
1501            const RSExportPrimitiveType *EPT = static_cast<const RSExportPrimitiveType*>(this);
1502            rtd->type = RSExportPrimitiveType::getRSReflectionType(EPT);
1503            return;
1504        }
1505    case RSExportType::ExportClassPointer: {
1506            const RSExportPointerType *EPT = static_cast<const RSExportPointerType*>(this);
1507            const RSExportType *PointeeType = EPT->getPointeeType();
1508            PointeeType->convertToRTD(rtd);
1509            rtd->isPointer = true;
1510            return;
1511        }
1512    case RSExportType::ExportClassVector: {
1513            const RSExportVectorType *EVT = static_cast<const RSExportVectorType*>(this);
1514            rtd->type = EVT->getRSReflectionType(EVT);
1515            rtd->vecSize = EVT->getNumElement();
1516            return;
1517        }
1518    case RSExportType::ExportClassMatrix: {
1519            const RSExportMatrixType *EMT = static_cast<const RSExportMatrixType*>(this);
1520            unsigned Dim = EMT->getDim();
1521            slangAssert((Dim >= 2) && (Dim <= 4));
1522            rtd->type = &gReflectionTypes[15 + Dim-2];
1523            return;
1524        }
1525    case RSExportType::ExportClassConstantArray: {
1526            const RSExportConstantArrayType* CAT =
1527              static_cast<const RSExportConstantArrayType*>(this);
1528            CAT->getElementType()->convertToRTD(rtd);
1529            rtd->arraySize = CAT->getSize();
1530            return;
1531        }
1532    case RSExportType::ExportClassRecord: {
1533            slangAssert(!"RSExportType::ExportClassRecord not implemented");
1534            return;// RS_TYPE_CLASS_NAME_PREFIX + ET->getName() + ".Item";
1535        }
1536    default: {
1537            slangAssert(false && "Unknown class of type");
1538        }
1539    }
1540}
1541
1542
1543}  // namespace slang
1544