slang_rs_export_type.cpp revision 41ebf534161bb67f6207a070c1f6a895dc853408
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_type.h"
18
19#include <vector>
20
21#include "llvm/Type.h"
22#include "llvm/DerivedTypes.h"
23
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/Target/TargetData.h"
26
27#include "clang/AST/RecordLayout.h"
28
29#include "slang_rs_context.h"
30#include "slang_rs_export_element.h"
31
// Makes the enclosing equals() return false when ParentClass's part of the
// equality check on E fails. The do/while(0) wrapper makes the expansion a
// single statement, so `CHECK_PARENT_EQUALITY(X, E);` is safe even as the
// body of an unbraced if/else.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  do {                                        \
    if (!ParentClass::equals(E))              \
      return false;                           \
  } while (0)
35
36using namespace slang;
37
38/****************************** RSExportType ******************************/
39bool RSExportType::NormalizeType(const clang::Type *&T,
40                                 llvm::StringRef &TypeName) {
41  llvm::SmallPtrSet<const clang::Type*, 8> SPS =
42      llvm::SmallPtrSet<const clang::Type*, 8>();
43
44  if ((T = RSExportType::TypeExportable(T, SPS)) == NULL)
45    // TODO(zonr): warn that type not exportable.
46    return false;
47
48  // Get type name
49  TypeName = RSExportType::GetTypeName(T);
50  if (TypeName.empty())
51    // TODO(zonr): warning that the type is unnamed.
52    return false;
53
54  return true;
55}
56
57const clang::Type
58*RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
59  if (DD) {
60    clang::QualType T;
61    if (DD->getTypeSourceInfo())
62      T = DD->getTypeSourceInfo()->getType();
63    else
64      T = DD->getType();
65
66    if (T.isNull())
67      return NULL;
68    else
69      return T.getTypePtr();
70  }
71  return NULL;
72}
73
74llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
75  T = GET_CANONICAL_TYPE(T);
76  if (T == NULL)
77    return llvm::StringRef();
78
79  switch (T->getTypeClass()) {
80    case clang::Type::Builtin: {
81      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
82
83      switch (BT->getKind()) {
84        // Compiler is smart enough to optimize following *big if branches*
85        // since they all become "constant comparison" after macro expansion
86#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
87        case builtin_type: {                                    \
88          if (type == RSExportPrimitiveType::DataTypeFloat32)           \
89            return "float";                                             \
90          else if (type == RSExportPrimitiveType::DataTypeFloat64)      \
91            return "double";                                            \
92          else if (type == RSExportPrimitiveType::DataTypeUnsigned8)    \
93            return "uchar";                                             \
94          else if (type == RSExportPrimitiveType::DataTypeUnsigned16)   \
95            return "ushort";                                            \
96          else if (type == RSExportPrimitiveType::DataTypeUnsigned32)   \
97            return "uint";                                              \
98          else if (type == RSExportPrimitiveType::DataTypeUnsigned64)   \
99            return "ulong";                                             \
100          else if (type == RSExportPrimitiveType::DataTypeSigned8)      \
101            return "char";                                              \
102          else if (type == RSExportPrimitiveType::DataTypeSigned16)     \
103            return "short";                                             \
104          else if (type == RSExportPrimitiveType::DataTypeSigned32)     \
105            return "int";                                               \
106          else if (type == RSExportPrimitiveType::DataTypeSigned64)     \
107            return "long";                                              \
108          else if (type == RSExportPrimitiveType::DataTypeBoolean)      \
109            return "bool";                                              \
110          else                                                          \
111            assert(false && "Unknow data type of supported builtin");   \
112          break;                                                        \
113        }
114#include "slang_rs_export_type_support.inc"
115
116          default: {
117            assert(false && "Unknown data type of the builtin");
118            break;
119          }
120        }
121      break;
122    }
123    case clang::Type::Record: {
124      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
125      llvm::StringRef Name = RD->getName();
126      if (Name.empty()) {
127          if (RD->getTypedefForAnonDecl() != NULL)
128            Name = RD->getTypedefForAnonDecl()->getName();
129
130          if (Name.empty())
131            // Try to find a name from redeclaration (i.e. typedef)
132            for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
133                     RE = RD->redecls_end();
134                 RI != RE;
135                 RI++) {
136              assert(*RI != NULL && "cannot be NULL object");
137
138              Name = (*RI)->getName();
139              if (!Name.empty())
140                break;
141            }
142      }
143      return Name;
144    }
145    case clang::Type::Pointer: {
146      // "*" plus pointee name
147      const clang::Type *PT = GET_POINTEE_TYPE(T);
148      llvm::StringRef PointeeName;
149      if (NormalizeType(PT, PointeeName)) {
150        char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
151        Name[0] = '*';
152        memcpy(Name + 1, PointeeName.data(), PointeeName.size());
153        Name[PointeeName.size() + 1] = '\0';
154        return Name;
155      }
156      break;
157    }
158    case clang::Type::ExtVector: {
159      const clang::ExtVectorType *EVT =
160          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
161      return RSExportVectorType::GetTypeName(EVT);
162      break;
163    }
164    case clang::Type::ConstantArray : {
165      // Construct name for a constant array is too complicated.
166      return DUMMY_TYPE_NAME_FOR_RS_CONSTANT_ARRAY_TYPE;
167    }
168    default: {
169      break;
170    }
171  }
172
173  return llvm::StringRef();
174}
175
176const clang::Type *RSExportType::TypeExportable(
177    const clang::Type *T,
178    llvm::SmallPtrSet<const clang::Type*, 8>& SPS) {
179  // Normalize first
180  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
181    return NULL;
182
183  if (SPS.count(T))
184    return T;
185
186  switch (T->getTypeClass()) {
187    case clang::Type::Builtin: {
188      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
189
190      switch (BT->getKind()) {
191#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
192        case builtin_type:
193#include "slang_rs_export_type_support.inc"
194        {
195          return T;
196        }
197        default: {
198          return NULL;
199        }
200      }
201      // Never be here
202    }
203    case clang::Type::Record: {
204      if (RSExportPrimitiveType::GetRSObjectType(T) !=
205          RSExportPrimitiveType::DataTypeUnknown)
206        return T;  // RS object type, no further checks are needed
207
208      // Check internal struct
209      const clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
210      if (RD != NULL)
211        RD = RD->getDefinition();
212
213      // Fast check
214      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
215        return NULL;
216
217      // Insert myself into checking set
218      SPS.insert(T);
219
220      // Check all element
221      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
222               FE = RD->field_end();
223           FI != FE;
224           FI++) {
225        const clang::FieldDecl *FD = *FI;
226        const clang::Type *FT = GetTypeOfDecl(FD);
227        FT = GET_CANONICAL_TYPE(FT);
228
229        if (!TypeExportable(FT, SPS)) {
230          fprintf(stderr, "Field `%s' in Record `%s' contains unsupported "
231                          "type\n", FD->getNameAsString().c_str(),
232                                    RD->getNameAsString().c_str());
233          FT->dump();
234          return NULL;
235        }
236      }
237
238      return T;
239    }
240    case clang::Type::Pointer: {
241      const clang::PointerType *PT = UNSAFE_CAST_TYPE(clang::PointerType, T);
242      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
243
244      if (PointeeType->getTypeClass() == clang::Type::Pointer)
245        return T;
246      // We don't support pointer with array-type pointee or unsupported pointee
247      // type
248      if (PointeeType->isArrayType() ||
249         (TypeExportable(PointeeType, SPS) == NULL) )
250        return NULL;
251      else
252        return T;
253    }
254    case clang::Type::ExtVector: {
255      const clang::ExtVectorType *EVT =
256          UNSAFE_CAST_TYPE(clang::ExtVectorType, T);
257      // Only vector with size 2, 3 and 4 are supported.
258      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
259        return NULL;
260
261      // Check base element type
262      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
263
264      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
265          (TypeExportable(ElementType, SPS) == NULL))
266        return NULL;
267      else
268        return T;
269    }
270    case clang::Type::ConstantArray: {
271      const clang::ConstantArrayType *CAT =
272          UNSAFE_CAST_TYPE(clang::ConstantArrayType, T);
273
274      // Check size
275      if (CAT->getSize().getActiveBits() > 32) {
276        fprintf(stderr, "RSExportConstantArrayType::Create : array with too "
277                        "large size (> 2^32).\n");
278        return NULL;
279      }
280      // Check element type
281      const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
282      if (ElementType->isArrayType()) {
283        fprintf(stderr, "RSExportType::TypeExportable : constant array with 2 "
284                        "or higher dimension of constant is not supported.\n");
285        return NULL;
286      }
287      if (TypeExportable(ElementType, SPS) == NULL)
288        return NULL;
289      else
290        return T;
291    }
292    default: {
293      return NULL;
294    }
295  }
296}
297
298RSExportType *RSExportType::Create(RSContext *Context,
299                                   const clang::Type *T,
300                                   const llvm::StringRef &TypeName) {
301  // Lookup the context to see whether the type was processed before.
302  // Newly created RSExportType will insert into context
303  // in RSExportType::RSExportType()
304  RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
305
306  if (ETI != Context->export_types_end())
307    return ETI->second;
308
309  RSExportType *ET = NULL;
310  switch (T->getTypeClass()) {
311    case clang::Type::Record: {
312      RSExportPrimitiveType::DataType dt =
313          RSExportPrimitiveType::GetRSObjectType(TypeName);
314      switch (dt) {
315        case RSExportPrimitiveType::DataTypeUnknown: {
316          // User-defined types
317          ET = RSExportRecordType::Create(Context,
318                                          T->getAsStructureType(),
319                                          TypeName);
320          break;
321        }
322        case RSExportPrimitiveType::DataTypeRSMatrix2x2: {
323          // 2 x 2 Matrix type
324          ET = RSExportMatrixType::Create(Context,
325                                          T->getAsStructureType(),
326                                          TypeName,
327                                          2);
328          break;
329        }
330        case RSExportPrimitiveType::DataTypeRSMatrix3x3: {
331          // 3 x 3 Matrix type
332          ET = RSExportMatrixType::Create(Context,
333                                          T->getAsStructureType(),
334                                          TypeName,
335                                          3);
336          break;
337        }
338        case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
339          // 4 x 4 Matrix type
340          ET = RSExportMatrixType::Create(Context,
341                                          T->getAsStructureType(),
342                                          TypeName,
343                                          4);
344          break;
345        }
346        default: {
347          // Others are primitive types
348          ET = RSExportPrimitiveType::Create(Context, T, TypeName);
349          break;
350        }
351      }
352      break;
353    }
354    case clang::Type::Builtin: {
355      ET = RSExportPrimitiveType::Create(Context, T, TypeName);
356      break;
357    }
358    case clang::Type::Pointer: {
359      ET = RSExportPointerType::Create(Context,
360                                       UNSAFE_CAST_TYPE(clang::PointerType, T),
361                                       TypeName);
362      // FIXME: free the name (allocated in RSExportType::GetTypeName)
363      delete [] TypeName.data();
364      break;
365    }
366    case clang::Type::ExtVector: {
367      ET = RSExportVectorType::Create(Context,
368                                      UNSAFE_CAST_TYPE(clang::ExtVectorType, T),
369                                      TypeName);
370      break;
371    }
372    case clang::Type::ConstantArray: {
373      ET = RSExportConstantArrayType::Create(
374              Context,
375              UNSAFE_CAST_TYPE(clang::ConstantArrayType, T));
376      break;
377    }
378    default: {
379      // TODO(zonr): warn that type is not exportable.
380      fprintf(stderr,
381              "RSExportType::Create : type '%s' is not exportable\n",
382              T->getTypeClassName());
383      break;
384    }
385  }
386
387  return ET;
388}
389
390RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
391  llvm::StringRef TypeName;
392  if (NormalizeType(T, TypeName))
393    return Create(Context, T, TypeName);
394  else
395    return NULL;
396}
397
398RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
399                                           const clang::VarDecl *VD) {
400  return RSExportType::Create(Context, GetTypeOfDecl(VD));
401}
402
403size_t RSExportType::GetTypeStoreSize(const RSExportType *ET) {
404  return ET->getRSContext()->getTargetData()->getTypeStoreSize(
405      ET->getLLVMType());
406}
407
408size_t RSExportType::GetTypeAllocSize(const RSExportType *ET) {
409  if (ET->getClass() == RSExportType::ExportClassRecord)
410    return static_cast<const RSExportRecordType*>(ET)->getAllocSize();
411  else
412    return ET->getRSContext()->getTargetData()->getTypeAllocSize(
413        ET->getLLVMType());
414}
415
416RSExportType::RSExportType(RSContext *Context,
417                           ExportClass Class,
418                           const llvm::StringRef &Name)
419    : RSExportable(Context, RSExportable::EX_TYPE),
420      mClass(Class),
421      // Make a copy on Name since memory stored @Name is either allocated in
422      // ASTContext or allocated in GetTypeName which will be destroyed later.
423      mName(Name.data(), Name.size()),
424      mLLVMType(NULL) {
425  // Don't cache the type whose name start with '<'. Those type failed to
426  // get their name since constructing their name in GetTypeName() requiring
427  // complicated work.
428  if (!Name.startswith(DUMMY_RS_TYPE_NAME_PREFIX))
429    // TODO(zonr): Need to check whether the insertion is successful or not.
430    Context->insertExportType(llvm::StringRef(Name), this);
431  return;
432}
433
434void RSExportType::keep() {
435  // Invalidate converted LLVM type.
436  mLLVMType = NULL;
437  RSExportable::keep();
438  return;
439}
440
441bool RSExportType::equals(const RSExportable *E) const {
442  CHECK_PARENT_EQUALITY(RSExportable, E);
443  return (static_cast<const RSExportType*>(E)->getClass() == getClass());
444}
445
446/************************** RSExportPrimitiveType **************************/
// Map from RS object type name to its DataType. It is filled the first time
// GetRSObjectType(const llvm::StringRef &) finds it empty.
llvm::ManagedStatic<RSExportPrimitiveType::RSObjectTypeMapTy>
RSExportPrimitiveType::RSObjectTypeMap;

// LLVM struct type shared by every RS object type. It is created the first
// time RSExportPrimitiveType::convertToLLVMType() needs it.
llvm::Type *RSExportPrimitiveType::RSObjectLLVMType = NULL;
451
452bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
453  if ((T != NULL) && (T->getTypeClass() == clang::Type::Builtin))
454    return true;
455  else
456    return false;
457}
458
459RSExportPrimitiveType::DataType
460RSExportPrimitiveType::GetRSObjectType(const llvm::StringRef &TypeName) {
461  if (TypeName.empty())
462    return DataTypeUnknown;
463
464  if (RSObjectTypeMap->empty()) {
465#define USE_ELEMENT_DATA_TYPE
466#define DEF_RS_OBJECT_TYPE(type, name)                                  \
467    RSObjectTypeMap->GetOrCreateValue(name, GET_ELEMENT_DATA_TYPE(type));
468#include "slang_rs_export_element_support.inc"
469  }
470
471  RSObjectTypeMapTy::const_iterator I = RSObjectTypeMap->find(TypeName);
472  if (I == RSObjectTypeMap->end())
473    return DataTypeUnknown;
474  else
475    return I->getValue();
476}
477
478RSExportPrimitiveType::DataType
479RSExportPrimitiveType::GetRSObjectType(const clang::Type *T) {
480  T = GET_CANONICAL_TYPE(T);
481  if ((T == NULL) || (T->getTypeClass() != clang::Type::Record))
482    return DataTypeUnknown;
483
484  return GetRSObjectType( RSExportType::GetTypeName(T) );
485}
486
// Size in bits of each DataType, indexed by the DataType value (see
// GetSizeInBits()).
// NOTE(review): the entry order must match the DataType enum in
// slang_rs_export_type.h. Confirm this whenever a data type is added. The
// array has DataTypeMax + 1 slots; the trailing 0 presumably fills the
// DataTypeMax slot.
const size_t
RSExportPrimitiveType::SizeOfDataTypeInBits[
    RSExportPrimitiveType::DataTypeMax + 1] = {
  16,  // DataTypeFloat16
  32,  // DataTypeFloat32
  64,  // DataTypeFloat64
  8,   // DataTypeSigned8
  16,  // DataTypeSigned16
  32,  // DataTypeSigned32
  64,  // DataTypeSigned64
  8,   // DataTypeUnsigned8
  16,  // DataTypeUnsigned16
  32,  // DataTypeUnsigned32
  64,  // DataTypeUnsigned64
  1,   // DataTypeBoolean

  // Packed 16-bit pixel formats.
  16,  // DataTypeUnsigned565
  16,  // DataTypeUnsigned5551
  16,  // DataTypeUnsigned4444

  // Matrices hold Dim * Dim 32-bit floats (2*2*32, 3*3*32, 4*4*32).
  128,  // DataTypeRSMatrix2x2
  288,  // DataTypeRSMatrix3x3
  512,  // DataTypeRSMatrix4x4

  // RS object handles are lowered to <{ [1 x i32] }> (32 bits).
  32,  // DataTypeRSElement
  32,  // DataTypeRSType
  32,  // DataTypeRSAllocation
  32,  // DataTypeRSSampler
  32,  // DataTypeRSScript
  32,  // DataTypeRSMesh
  32,  // DataTypeRSProgramFragment
  32,  // DataTypeRSProgramVertex
  32,  // DataTypeRSProgramRaster
  32,  // DataTypeRSProgramStore
  32,  // DataTypeRSFont
  0
};
524
525size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
526  assert(((EPT->getType() >= DataTypeFloat32) &&
527          (EPT->getType() < DataTypeMax)) &&
528         "RSExportPrimitiveType::GetSizeInBits : unknown data type");
529  return SizeOfDataTypeInBits[ static_cast<int>(EPT->getType()) ];
530}
531
532RSExportPrimitiveType::DataType
533RSExportPrimitiveType::GetDataType(const clang::Type *T) {
534  if (T == NULL)
535    return DataTypeUnknown;
536
537  switch (T->getTypeClass()) {
538    case clang::Type::Builtin: {
539      const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType, T);
540      switch (BT->getKind()) {
541#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
542        case builtin_type: {                                    \
543          return type;                                          \
544          break;                                                \
545        }
546#include "slang_rs_export_type_support.inc"
547
548        // The size of types Long, ULong and WChar depend on platform so we
549        // abandon the support to them. Type of its size exceeds 32 bits (e.g.
550        // int64_t, double, etc.): no support
551
552        default: {
553          // TODO(zonr): warn that the type is unsupported
554          fprintf(stderr, "RSExportPrimitiveType::GetDataType : built-in type "
555                          "has no corresponding data type for built-in type");
556          break;
557        }
558      }
559      break;
560    }
561
562    case clang::Type::Record: {
563      // must be RS object type
564      return RSExportPrimitiveType::GetRSObjectType(T);
565      break;
566    }
567
568    default: {
569      fprintf(stderr, "RSExportPrimitiveType::GetDataType : type '%s' is not "
570                      "supported primitive type", T->getTypeClassName());
571      break;
572    }
573  }
574
575  return DataTypeUnknown;
576}
577
578RSExportPrimitiveType
579*RSExportPrimitiveType::Create(RSContext *Context,
580                               const clang::Type *T,
581                               const llvm::StringRef &TypeName,
582                               DataKind DK,
583                               bool Normalized) {
584  DataType DT = GetDataType(T);
585
586  if ((DT == DataTypeUnknown) || TypeName.empty())
587    return NULL;
588  else
589    return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
590                                     DT, DK, Normalized);
591}
592
593RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
594                                                     const clang::Type *T,
595                                                     DataKind DK) {
596  llvm::StringRef TypeName;
597  if (RSExportType::NormalizeType(T, TypeName) && IsPrimitiveType(T))
598    return Create(Context, T, TypeName, DK);
599  else
600    return NULL;
601}
602
603const llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
604  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
605
606  if (isRSObjectType()) {
607    // struct {
608    //   int *p;
609    // } __attribute__((packed, aligned(pointer_size)))
610    //
611    // which is
612    //
613    // <{ [1 x i32] }> in LLVM
614    //
615    if (RSObjectLLVMType == NULL) {
616      std::vector<const llvm::Type *> Elements;
617      Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
618      RSObjectLLVMType = llvm::StructType::get(C, Elements, true);
619    }
620    return RSObjectLLVMType;
621  }
622
623  switch (mType) {
624    case DataTypeFloat32: {
625      return llvm::Type::getFloatTy(C);
626      break;
627    }
628    case DataTypeFloat64: {
629      return llvm::Type::getDoubleTy(C);
630      break;
631    }
632    case DataTypeBoolean: {
633      return llvm::Type::getInt1Ty(C);
634      break;
635    }
636    case DataTypeSigned8:
637    case DataTypeUnsigned8: {
638      return llvm::Type::getInt8Ty(C);
639      break;
640    }
641    case DataTypeSigned16:
642    case DataTypeUnsigned16:
643    case DataTypeUnsigned565:
644    case DataTypeUnsigned5551:
645    case DataTypeUnsigned4444: {
646      return llvm::Type::getInt16Ty(C);
647      break;
648    }
649    case DataTypeSigned32:
650    case DataTypeUnsigned32: {
651      return llvm::Type::getInt32Ty(C);
652      break;
653    }
654    case DataTypeSigned64:
655    case DataTypeUnsigned64: {
656      return llvm::Type::getInt64Ty(C);
657      break;
658    }
659    default: {
660      assert(false && "Unknown data type");
661    }
662  }
663
664  return NULL;
665}
666
667bool RSExportPrimitiveType::equals(const RSExportable *E) const {
668  CHECK_PARENT_EQUALITY(RSExportType, E);
669  return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
670}
671
672/**************************** RSExportPointerType ****************************/
673
// Builtin integer type used as the pointee when a pointer to a pointer is
// exported (RSExportPointerType::Create exports such pointers as int*). It
// must be set before such a type is exported; Create() asserts that it is
// non-NULL.
const clang::Type *RSExportPointerType::IntegerType = NULL;
675
676RSExportPointerType
677*RSExportPointerType::Create(RSContext *Context,
678                             const clang::PointerType *PT,
679                             const llvm::StringRef &TypeName) {
680  const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
681  const RSExportType *PointeeET;
682
683  if (PointeeType->getTypeClass() != clang::Type::Pointer) {
684    PointeeET = RSExportType::Create(Context, PointeeType);
685  } else {
686    // Double or higher dimension of pointer, export as int*
687    assert(IntegerType != NULL && "Built-in integer type is not set");
688    PointeeET = RSExportPrimitiveType::Create(Context, IntegerType);
689  }
690
691  if (PointeeET == NULL) {
692    fprintf(stderr, "Failed to create type for pointee");
693    return NULL;
694  }
695
696  return new RSExportPointerType(Context, TypeName, PointeeET);
697}
698
699const llvm::Type *RSExportPointerType::convertToLLVMType() const {
700  const llvm::Type *PointeeType = mPointeeType->getLLVMType();
701  return llvm::PointerType::getUnqual(PointeeType);
702}
703
704void RSExportPointerType::keep() {
705  const_cast<RSExportType*>(mPointeeType)->keep();
706  RSExportType::keep();
707}
708
709bool RSExportPointerType::equals(const RSExportable *E) const {
710  CHECK_PARENT_EQUALITY(RSExportType, E);
711  return (static_cast<const RSExportPointerType*>(E)
712              ->getPointeeType()->equals(getPointeeType()));
713}
714
715/***************************** RSExportVectorType *****************************/
// Exported names of the supported vector types.
// Row: element type, chosen in RSExportVectorType::GetTypeName().
// Column: element count - 2 (2-, 3- and 4-element vectors).
const char* RSExportVectorType::VectorTypeNameStore[][3] = {
  /* 0 */ { "char2",      "char3",    "char4" },
  /* 1 */ { "uchar2",     "uchar3",   "uchar4" },
  /* 2 */ { "short2",     "short3",   "short4" },
  /* 3 */ { "ushort2",    "ushort3",  "ushort4" },
  /* 4 */ { "int2",       "int3",     "int4" },
  /* 5 */ { "uint2",      "uint3",    "uint4" },
  /* 6 */ { "long2",      "long3",    "long4" },
  /* 7 */ { "ulong2",     "ulong3",   "ulong4" },
  /* 8 */ { "float2",     "float3",   "float4" },
  /* 9 */ { "double2",    "double3",  "double4" },
};
728
// Returns the exported name of ext-vector type EVT (e.g. "float4"), or an
// empty StringRef if its element type is not a supported builtin or its
// element count is outside 2..4. Returned names point into the static
// VectorTypeNameStore table.
llvm::StringRef
RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);

  if ((ElementType->getTypeClass() != clang::Type::Builtin))
    return llvm::StringRef();

  const clang::BuiltinType *BT = UNSAFE_CAST_TYPE(clang::BuiltinType,
                                                  ElementType);
  // Row of VectorTypeNameStore for this element type; stays NULL if the
  // element's data type has no row.
  const char **BaseElement = NULL;

  switch (BT->getKind()) {
    // Compiler is smart enough to optimize following *big if branches* since
    // they all become "constant comparison" after macro expansion.
    // NOTE(review): DataTypeBoolean maps to row 0 ("charN"), so bool vectors
    // are named like char vectors. This looks intentional, but confirm.
#define SLANG_RS_SUPPORT_BUILTIN_TYPE(builtin_type, type)       \
    case builtin_type: {                                                \
      if (type == RSExportPrimitiveType::DataTypeSigned8) \
        BaseElement = VectorTypeNameStore[0];                           \
      else if (type == RSExportPrimitiveType::DataTypeUnsigned8) \
        BaseElement = VectorTypeNameStore[1];                           \
      else if (type == RSExportPrimitiveType::DataTypeSigned16) \
        BaseElement = VectorTypeNameStore[2];                           \
      else if (type == RSExportPrimitiveType::DataTypeUnsigned16) \
        BaseElement = VectorTypeNameStore[3];                           \
      else if (type == RSExportPrimitiveType::DataTypeSigned32) \
        BaseElement = VectorTypeNameStore[4];                           \
      else if (type == RSExportPrimitiveType::DataTypeUnsigned32) \
        BaseElement = VectorTypeNameStore[5];                           \
      else if (type == RSExportPrimitiveType::DataTypeSigned64) \
        BaseElement = VectorTypeNameStore[6];                           \
      else if (type == RSExportPrimitiveType::DataTypeUnsigned64) \
        BaseElement = VectorTypeNameStore[7];                           \
      else if (type == RSExportPrimitiveType::DataTypeFloat32) \
        BaseElement = VectorTypeNameStore[8];                           \
      else if (type == RSExportPrimitiveType::DataTypeFloat64) \
        BaseElement = VectorTypeNameStore[9];                           \
      else if (type == RSExportPrimitiveType::DataTypeBoolean) \
        BaseElement = VectorTypeNameStore[0];                          \
      break;  \
    }
#include "slang_rs_export_type_support.inc"
    default: {
      return llvm::StringRef();
    }
  }

  // Columns hold 2-, 3- and 4-element names, hence the "- 2".
  if ((BaseElement != NULL) &&
      (EVT->getNumElements() > 1) &&
      (EVT->getNumElements() <= 4))
    return BaseElement[EVT->getNumElements() - 2];
  else
    return llvm::StringRef();
}
782
783RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
784                                               const clang::ExtVectorType *EVT,
785                                               const llvm::StringRef &TypeName,
786                                               DataKind DK,
787                                               bool Normalized) {
788  assert(EVT != NULL && EVT->getTypeClass() == clang::Type::ExtVector);
789
790  const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
791  RSExportPrimitiveType::DataType DT =
792      RSExportPrimitiveType::GetDataType(ElementType);
793
794  if (DT != RSExportPrimitiveType::DataTypeUnknown)
795    return new RSExportVectorType(Context,
796                                  TypeName,
797                                  DT,
798                                  DK,
799                                  Normalized,
800                                  EVT->getNumElements());
801  else
802    fprintf(stderr, "RSExportVectorType::Create : unsupported base element "
803                    "type\n");
804  return NULL;
805}
806
807const llvm::Type *RSExportVectorType::convertToLLVMType() const {
808  const llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
809  return llvm::VectorType::get(ElementType, getNumElement());
810}
811
812bool RSExportVectorType::equals(const RSExportable *E) const {
813  CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
814  return (static_cast<const RSExportVectorType*>(E)->getNumElement()
815              == getNumElement());
816}
817
818/***************************** RSExportMatrixType *****************************/
819RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
820                                               const clang::RecordType *RT,
821                                               const llvm::StringRef &TypeName,
822                                               unsigned Dim) {
823  assert((RT != NULL) && (RT->getTypeClass() == clang::Type::Record));
824  assert((Dim > 1) && "Invalid dimension of matrix");
825
826  // Check whether the struct rs_matrix is in our expected form (but assume it's
827  // correct if we're not sure whether it's correct or not)
828  const clang::RecordDecl* RD = RT->getDecl();
829  RD = RD->getDefinition();
830  if (RD != NULL) {
831    // Find definition, perform further examination
832    if (RD->field_empty()) {
833      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
834                      "must have 1 field for saving values", TypeName.data());
835      return NULL;
836    }
837
838    clang::RecordDecl::field_iterator FIT = RD->field_begin();
839    const clang::FieldDecl *FD = *FIT;
840    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
841    if ((FT == NULL) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
842      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
843                      "first field should be an array with constant size",
844              TypeName.data());
845      return NULL;
846    }
847    const clang::ConstantArrayType *CAT =
848      static_cast<const clang::ConstantArrayType *>(FT);
849    const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
850    if ((ElementType == NULL) ||
851        (ElementType->getTypeClass() != clang::Type::Builtin) ||
852        (static_cast<const clang::BuiltinType *>(ElementType)->getKind()
853          != clang::BuiltinType::Float)) {
854      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
855                      "first field should be a float array", TypeName.data());
856      return NULL;
857    }
858
859    if (CAT->getSize() != Dim * Dim) {
860      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
861                      "first field should be an array with size %d",
862              TypeName.data(), Dim * Dim);
863      return NULL;
864    }
865
866    FIT++;
867    if (FIT != RD->field_end()) {
868      fprintf(stderr, "RSExportMatrixType::Create : invalid %s struct: "
869                      "must have exactly 1 field", TypeName.data());
870      return NULL;
871    }
872  }
873
874  return new RSExportMatrixType(Context, TypeName, Dim);
875}
876
877const llvm::Type *RSExportMatrixType::convertToLLVMType() const {
878  // Construct LLVM type:
879  // struct {
880  //  float X[mDim * mDim];
881  // }
882
883  llvm::LLVMContext &C = getRSContext()->getLLVMContext();
884  llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
885                                            mDim * mDim);
886  return llvm::StructType::get(C, X, NULL);
887}
888
889bool RSExportMatrixType::equals(const RSExportable *E) const {
890  CHECK_PARENT_EQUALITY(RSExportType, E);
891  return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
892}
893
894/************************* RSExportConstantArrayType *************************/
895RSExportConstantArrayType
896*RSExportConstantArrayType::Create(RSContext *Context,
897                                   const clang::ConstantArrayType *CAT) {
898  assert(CAT != NULL && CAT->getTypeClass() == clang::Type::ConstantArray);
899
900  assert((CAT->getSize().getActiveBits() < 32) && "array too large");
901
902  unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
903  assert((Size > 0) && "Constant array should have size greater than 0");
904
905  const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
906  RSExportType *ElementET = RSExportType::Create(Context, ElementType);
907
908  if (ElementET == NULL) {
909    fprintf(stderr, "RSExportConstantArrayType::Create : failed to create "
910                    "RSExportType for array element.\n");
911    return NULL;
912  }
913
914  return new RSExportConstantArrayType(Context,
915                                       ElementET,
916                                       Size);
917}
918
919const llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
920  return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
921}
922
923void RSExportConstantArrayType::keep() {
924  const_cast<RSExportType*>(mElementType)->keep();
925  RSExportType::keep();
926  return;
927}
928
929bool RSExportConstantArrayType::equals(const RSExportable *E) const {
930  CHECK_PARENT_EQUALITY(RSExportType, E);
931  return ((static_cast<const RSExportConstantArrayType*>(E)
932              ->getSize() == getSize()) && (mElementType->equals(E)));
933}
934
935/**************************** RSExportRecordType ****************************/
936RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
937                                               const clang::RecordType *RT,
938                                               const llvm::StringRef &TypeName,
939                                               bool mIsArtificial) {
940  assert(RT != NULL && RT->getTypeClass() == clang::Type::Record);
941
942  const clang::RecordDecl *RD = RT->getDecl();
943  assert(RD->isStruct());
944
945  RD = RD->getDefinition();
946  if (RD == NULL) {
947    // TODO(zonr): warn that actual struct definition isn't declared in this
948    //             moudle.
949    fprintf(stderr, "RSExportRecordType::Create : this struct is not defined "
950                    "in this module.");
951    return NULL;
952  }
953
954  // Struct layout construct by clang. We rely on this for obtaining the
955  // alloc size of a struct and offset of every field in that struct.
956  const clang::ASTRecordLayout *RL =
957      &Context->getASTContext()->getASTRecordLayout(RD);
958  assert((RL != NULL) && "Failed to retrieve the struct layout from Clang.");
959
960  RSExportRecordType *ERT =
961      new RSExportRecordType(Context,
962                             TypeName,
963                             RD->hasAttr<clang::PackedAttr>(),
964                             mIsArtificial,
965                             (RL->getSize() >> 3));
966  unsigned int Index = 0;
967
968  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
969           FE = RD->field_end();
970       FI != FE;
971       FI++, Index++) {
972#define FAILED_CREATE_FIELD(err)    do {         \
973      if (*err)                                                          \
974        fprintf(stderr, \
975                "RSExportRecordType::Create : failed to create field (%s)\n", \
976                err);                                                   \
977      delete ERT;                                                       \
978      return NULL;                                                      \
979    } while (false)
980
981    // FIXME: All fields should be primitive type
982    assert((*FI)->getKind() == clang::Decl::Field);
983    clang::FieldDecl *FD = *FI;
984
985    // We don't support bit field
986    //
987    // TODO(zonr): allow bitfield with size 8, 16, 32
988    if (FD->isBitField())
989      FAILED_CREATE_FIELD("bit field is not supported");
990
991    // Type
992    RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
993
994    if (ET != NULL)
995      ERT->mFields.push_back(
996          new Field(ET, FD->getName(), ERT,
997                    static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
998    else
999      FAILED_CREATE_FIELD(FD->getName().str().c_str());
1000#undef FAILED_CREATE_FIELD
1001  }
1002
1003  return ERT;
1004}
1005
1006const llvm::Type *RSExportRecordType::convertToLLVMType() const {
1007  std::vector<const llvm::Type*> FieldTypes;
1008
1009  for (const_field_iterator FI = fields_begin(),
1010           FE = fields_end();
1011       FI != FE;
1012       FI++) {
1013    const Field *F = *FI;
1014    const RSExportType *FET = F->getType();
1015
1016    FieldTypes.push_back(FET->getLLVMType());
1017  }
1018
1019  return llvm::StructType::get(getRSContext()->getLLVMContext(),
1020                               FieldTypes,
1021                               mIsPacked);
1022}
1023
1024void RSExportRecordType::keep() {
1025  for (std::list<const Field*>::iterator I = mFields.begin(),
1026          E = mFields.end();
1027       I != E;
1028       I++) {
1029    const_cast<RSExportType*>((*I)->getType())->keep();
1030  }
1031  RSExportType::keep();
1032  return;
1033}
1034
1035bool RSExportRecordType::equals(const RSExportable *E) const {
1036  CHECK_PARENT_EQUALITY(RSExportType, E);
1037
1038  const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1039
1040  if (ERT->getFields().size() != getFields().size())
1041    return false;
1042
1043  const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1044
1045  for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1046    if (!(*AI)->getType()->equals((*BI)->getType()))
1047      return false;
1048    AI++;
1049    BI++;
1050  }
1051
1052  return true;
1053}
1054