slang_rs_backend.cpp revision b1771ef128b10c4d4575634828006bfba20b1d9c
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <stack>
20#include <vector>
21#include <string>
22
23#include "llvm/Metadata.h"
24#include "llvm/Constant.h"
25#include "llvm/Constants.h"
26#include "llvm/Module.h"
27#include "llvm/Function.h"
28#include "llvm/DerivedTypes.h"
29
30#include "llvm/Support/IRBuilder.h"
31
32#include "llvm/ADT/Twine.h"
33#include "llvm/ADT/StringExtras.h"
34
35#include "clang/AST/DeclGroup.h"
36#include "clang/AST/Expr.h"
37#include "clang/AST/OperationKinds.h"
38#include "clang/AST/Stmt.h"
39#include "clang/AST/StmtVisitor.h"
40
41#include "slang_rs.h"
42#include "slang_rs_context.h"
43#include "slang_rs_metadata.h"
44#include "slang_rs_export_var.h"
45#include "slang_rs_export_func.h"
46#include "slang_rs_export_type.h"
47
48using namespace slang;
49
50RSBackend::RSBackend(RSContext *Context,
51                     clang::Diagnostic &Diags,
52                     const clang::CodeGenOptions &CodeGenOpts,
53                     const clang::TargetOptions &TargetOpts,
54                     const PragmaList &Pragmas,
55                     llvm::raw_ostream *OS,
56                     Slang::OutputType OT,
57                     clang::SourceManager &SourceMgr,
58                     bool AllowRSPrefix)
59    : Backend(Diags,
60              CodeGenOpts,
61              TargetOpts,
62              Pragmas,
63              OS,
64              OT),
65      mContext(Context),
66      mSourceMgr(SourceMgr),
67      mAllowRSPrefix(AllowRSPrefix),
68      mExportVarMetadata(NULL),
69      mExportFuncMetadata(NULL),
70      mExportTypeMetadata(NULL) {
71  return;
72}
73
74void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
75  // Disallow user-defined functions with prefix "rs"
76  if (!mAllowRSPrefix) {
77    // Iterate all function declarations in the program.
78    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
79         I != E; I++) {
80      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
81      if (FD == NULL)
82        continue;
83      if (!FD->getName().startswith("rs"))  // Check prefix
84        continue;
85      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
86        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
87                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
88                                             "invalid function name prefix, "
89                                             "\"rs\" is reserved: '%0'"))
90            << FD->getName();
91    }
92  }
93
94  Backend::HandleTopLevelDecl(D);
95  return;
96}
97///////////////////////////////////////////////////////////////////////////////
98
99namespace {
100
101  class RSObjectRefCounting : public clang::StmtVisitor<RSObjectRefCounting> {
102   private:
103    class Scope {
104     private:
105      clang::CompoundStmt *mCS;      // Associated compound statement ({ ... })
106      std::list<clang::Decl*> mRSO;  // Declared RS object in this scope
107
108     public:
109      Scope(clang::CompoundStmt *CS) : mCS(CS) {
110        return;
111      }
112
113      inline void addRSObject(clang::Decl* D) { mRSO.push_back(D); }
114    };
115    std::stack<Scope*> mScopeStack;
116
117    inline Scope *getCurrentScope() { return mScopeStack.top(); }
118
119    // Return false if the type of variable declared in VD is not an RS object
120    // type.
121    static bool InitializeRSObject(clang::VarDecl *VD);
122    // Return an zero-initializer expr of the type DT. This processes both
123    // RS matrix type and RS object type.
124    static clang::Expr *CreateZeroInitializerForRSSpecificType(
125        RSExportPrimitiveType::DataType DT,
126        clang::ASTContext &C,
127        const clang::SourceLocation &Loc);
128
129   public:
130    void VisitChildren(clang::Stmt *S) {
131      for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
132           I != E;
133           I++)
134        if (clang::Stmt *Child = *I)
135          Visit(Child);
136    }
137    void VisitStmt(clang::Stmt *S) { VisitChildren(S); }
138
139    void VisitDeclStmt(clang::DeclStmt *DS);
140    void VisitCompoundStmt(clang::CompoundStmt *CS);
141    void VisitBinAssign(clang::BinaryOperator *AS);
142
143    // We believe that RS objects never are involved in CompoundAssignOperator.
144    // I.e., rs_allocation foo; foo += bar;
145  };
146}
147
148bool RSObjectRefCounting::InitializeRSObject(clang::VarDecl *VD) {
149  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
150  RSExportPrimitiveType::DataType DT =
151      RSExportPrimitiveType::GetRSSpecificType(T);
152
153  if (DT == RSExportPrimitiveType::DataTypeUnknown)
154    return false;
155
156  if (VD->hasInit()) {
157    // TODO: Update the reference count of RS object in initializer.
158    // This can potentially be done as part of the assignment pass.
159  } else {
160    clang::Expr *ZeroInitializer =
161        CreateZeroInitializerForRSSpecificType(DT,
162                                               VD->getASTContext(),
163                                               VD->getLocation());
164
165    if (ZeroInitializer) {
166      ZeroInitializer->setType(T->getCanonicalTypeInternal());
167      VD->setInit(ZeroInitializer);
168    }
169  }
170
171  return RSExportPrimitiveType::IsRSObjectType(DT);
172}
173
174clang::Expr *RSObjectRefCounting::CreateZeroInitializerForRSSpecificType(
175    RSExportPrimitiveType::DataType DT,
176    clang::ASTContext &C,
177    const clang::SourceLocation &Loc) {
178  clang::Expr *Res = NULL;
179  switch (DT) {
180    case RSExportPrimitiveType::DataTypeRSElement:
181    case RSExportPrimitiveType::DataTypeRSType:
182    case RSExportPrimitiveType::DataTypeRSAllocation:
183    case RSExportPrimitiveType::DataTypeRSSampler:
184    case RSExportPrimitiveType::DataTypeRSScript:
185    case RSExportPrimitiveType::DataTypeRSMesh:
186    case RSExportPrimitiveType::DataTypeRSProgramFragment:
187    case RSExportPrimitiveType::DataTypeRSProgramVertex:
188    case RSExportPrimitiveType::DataTypeRSProgramRaster:
189    case RSExportPrimitiveType::DataTypeRSProgramStore:
190    case RSExportPrimitiveType::DataTypeRSFont: {
191      //    (ImplicitCastExpr 'nullptr_t'
192      //      (IntegerLiteral 0)))
193      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
194      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
195      clang::Expr *CastToNull =
196          clang::ImplicitCastExpr::Create(C,
197                                          C.NullPtrTy,
198                                          clang::CK_IntegralToPointer,
199                                          Int0,
200                                          NULL,
201                                          clang::VK_RValue);
202
203      Res = new (C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
204      break;
205    }
206    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
207    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
208    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
209      // RS matrix is not completely an RS object. They hold data by themselves.
210      // (InitListExpr rs_matrix2x2
211      //   (InitListExpr float[4]
212      //     (FloatingLiteral 0)
213      //     (FloatingLiteral 0)
214      //     (FloatingLiteral 0)
215      //     (FloatingLiteral 0)))
216      clang::QualType FloatTy = C.FloatTy;
217      // Constructor sets value to 0.0f by default
218      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
219      clang::FloatingLiteral *Float0Val =
220          clang::FloatingLiteral::Create(C,
221                                         Val,
222                                         /* isExact = */true,
223                                         FloatTy,
224                                         Loc);
225
226      unsigned N = 0;
227      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
228        N = 2;
229      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
230        N = 3;
231      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
232        N = 4;
233
234      // Directly allocate 16 elements instead of dynamically allocate N*N
235      clang::Expr *InitVals[16];
236      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
237        InitVals[i] = Float0Val;
238      clang::Expr *InitExpr =
239          new (C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
240      InitExpr->setType(C.getConstantArrayType(FloatTy,
241                                               llvm::APInt(32, 4),
242                                               clang::ArrayType::Normal,
243                                               /* EltTypeQuals = */0));
244
245      Res = new (C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
246      break;
247    }
248    case RSExportPrimitiveType::DataTypeUnknown:
249    case RSExportPrimitiveType::DataTypeFloat16:
250    case RSExportPrimitiveType::DataTypeFloat32:
251    case RSExportPrimitiveType::DataTypeFloat64:
252    case RSExportPrimitiveType::DataTypeSigned8:
253    case RSExportPrimitiveType::DataTypeSigned16:
254    case RSExportPrimitiveType::DataTypeSigned32:
255    case RSExportPrimitiveType::DataTypeSigned64:
256    case RSExportPrimitiveType::DataTypeUnsigned8:
257    case RSExportPrimitiveType::DataTypeUnsigned16:
258    case RSExportPrimitiveType::DataTypeUnsigned32:
259    case RSExportPrimitiveType::DataTypeUnsigned64:
260    case RSExportPrimitiveType::DataTypeBoolean:
261    case RSExportPrimitiveType::DataTypeUnsigned565:
262    case RSExportPrimitiveType::DataTypeUnsigned5551:
263    case RSExportPrimitiveType::DataTypeUnsigned4444:
264    case RSExportPrimitiveType::DataTypeMax: {
265      assert(false && "Not RS object type!");
266    }
267    // No default case will enable compiler detecting the missing cases
268  }
269
270  return Res;
271}
272
273void RSObjectRefCounting::VisitDeclStmt(clang::DeclStmt *DS) {
274  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
275       I != E;
276       I++) {
277    clang::Decl *D = *I;
278    if (D->getKind() == clang::Decl::Var) {
279      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
280      if (InitializeRSObject(VD))
281        getCurrentScope()->addRSObject(VD);
282    }
283  }
284  return;
285}
286
287void RSObjectRefCounting::VisitCompoundStmt(clang::CompoundStmt *CS) {
288  if (!CS->body_empty()) {
289    // Push a new scope
290    Scope *S = new Scope(CS);
291    mScopeStack.push(S);
292
293    VisitChildren(CS);
294
295    // Destroy the scope
296    // TODO: Update reference count of the RS object refenced by the
297    //       getCurrentScope().
298    assert((getCurrentScope() == S) && "Corrupted scope stack!");
299    mScopeStack.pop();
300    delete S;
301  }
302  return;
303}
304
305void RSObjectRefCounting::VisitBinAssign(clang::BinaryOperator *AS) {
306  // TODO: Update reference count
307  return;
308}
309
310void RSBackend::HandleTranslationUnitPre(clang::ASTContext& C) {
311  RSObjectRefCounting RSObjectRefCounter;
312  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
313
314  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
315          E = TUDecl->decls_end(); I != E; I++) {
316    if ((I->getKind() >= clang::Decl::firstFunction) &&
317        (I->getKind() <= clang::Decl::lastFunction)) {
318      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
319      if (FD->hasBody() && !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
320        RSObjectRefCounter.Visit( FD->getBody());
321    }
322  }
323
324  return;
325}
326
327///////////////////////////////////////////////////////////////////////////////
328void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
329  mContext->processExport();
330
331  // Dump export variable info
332  if (mContext->hasExportVar()) {
333    if (mExportVarMetadata == NULL)
334      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
335
336    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
337
338    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
339            E = mContext->export_vars_end();
340         I != E;
341         I++) {
342      const RSExportVar *EV = *I;
343      const RSExportType *ET = EV->getType();
344
345      // Variable name
346      ExportVarInfo.push_back(
347          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
348
349      // Type name
350      switch (ET->getClass()) {
351        case RSExportType::ExportClassPrimitive: {
352          ExportVarInfo.push_back(
353              llvm::MDString::get(
354                mLLVMContext, llvm::utostr_32(
355                  static_cast<const RSExportPrimitiveType*>(ET)->getType())));
356          break;
357        }
358        case RSExportType::ExportClassPointer: {
359          ExportVarInfo.push_back(
360              llvm::MDString::get(
361                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
362                  ->getPointeeType()->getName()).c_str()));
363          break;
364        }
365        case RSExportType::ExportClassMatrix: {
366          ExportVarInfo.push_back(
367              llvm::MDString::get(
368                mLLVMContext, llvm::utostr_32(
369                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
370                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
371          break;
372        }
373        case RSExportType::ExportClassVector:
374        case RSExportType::ExportClassConstantArray:
375        case RSExportType::ExportClassRecord: {
376          ExportVarInfo.push_back(
377              llvm::MDString::get(mLLVMContext,
378                EV->getType()->getName().c_str()));
379          break;
380        }
381      }
382
383      mExportVarMetadata->addOperand(
384          llvm::MDNode::get(mLLVMContext,
385                            ExportVarInfo.data(),
386                            ExportVarInfo.size()) );
387
388      ExportVarInfo.clear();
389    }
390  }
391
392  // Dump export function info
393  if (mContext->hasExportFunc()) {
394    if (mExportFuncMetadata == NULL)
395      mExportFuncMetadata =
396          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
397
398    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
399
400    for (RSContext::const_export_func_iterator
401            I = mContext->export_funcs_begin(),
402            E = mContext->export_funcs_end();
403         I != E;
404         I++) {
405      const RSExportFunc *EF = *I;
406
407      // Function name
408      if (!EF->hasParam()) {
409        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
410                                                     EF->getName().c_str()));
411      } else {
412        llvm::Function *F = M->getFunction(EF->getName());
413        llvm::Function *HelperFunction;
414        const std::string HelperFunctionName(".helper_" + EF->getName());
415
416        assert(F && "Function marked as exported disappeared in Bitcode");
417
418        // Create helper function
419        {
420          llvm::StructType *HelperFunctionParameterTy = NULL;
421
422          if (!F->getArgumentList().empty()) {
423            std::vector<const llvm::Type*> HelperFunctionParameterTys;
424            for (llvm::Function::arg_iterator AI = F->arg_begin(),
425                 AE = F->arg_end(); AI != AE; AI++)
426              HelperFunctionParameterTys.push_back(AI->getType());
427
428            HelperFunctionParameterTy =
429                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
430          }
431
432          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
433            fprintf(stderr, "Failed to export function %s: parameter type "
434                            "mismatch during creation of helper function.\n",
435                    EF->getName().c_str());
436
437            const RSExportRecordType *Expected = EF->getParamPacketType();
438            if (Expected) {
439              fprintf(stderr, "Expected:\n");
440              Expected->getLLVMType()->dump();
441            }
442            if (HelperFunctionParameterTy) {
443              fprintf(stderr, "Got:\n");
444              HelperFunctionParameterTy->dump();
445            }
446          }
447
448          std::vector<const llvm::Type*> Params;
449          if (HelperFunctionParameterTy) {
450            llvm::PointerType *HelperFunctionParameterTyP =
451                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
452            Params.push_back(HelperFunctionParameterTyP);
453          }
454
455          llvm::FunctionType * HelperFunctionType =
456              llvm::FunctionType::get(F->getReturnType(),
457                                      Params,
458                                      /* IsVarArgs = */false);
459
460          HelperFunction =
461              llvm::Function::Create(HelperFunctionType,
462                                     llvm::GlobalValue::ExternalLinkage,
463                                     HelperFunctionName,
464                                     M);
465
466          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
467          HelperFunction->setCallingConv(F->getCallingConv());
468
469          // Create helper function body
470          {
471            llvm::Argument *HelperFunctionParameter =
472                &(*HelperFunction->arg_begin());
473            llvm::BasicBlock *BB =
474                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
475            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
476            llvm::SmallVector<llvm::Value*, 6> Params;
477            llvm::Value *Idx[2];
478
479            Idx[0] =
480                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
481
482            // getelementptr and load instruction for all elements in
483            // parameter .p
484            for (size_t i = 0; i < EF->getNumParameters(); i++) {
485              // getelementptr
486              Idx[1] =
487                  llvm::ConstantInt::get(
488                      llvm::Type::getInt32Ty(mLLVMContext), i);
489              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
490                                                       Idx,
491                                                       Idx + 2);
492
493              // load
494              llvm::Value *V = IB->CreateLoad(Ptr);
495              Params.push_back(V);
496            }
497
498            // Call and pass the all elements as paramter to F
499            llvm::CallInst *CI = IB->CreateCall(F,
500                                                Params.data(),
501                                                Params.data() + Params.size());
502
503            CI->setCallingConv(F->getCallingConv());
504
505            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
506              IB->CreateRetVoid();
507            else
508              IB->CreateRet(CI);
509
510            delete IB;
511          }
512        }
513
514        ExportFuncInfo.push_back(
515            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
516      }
517
518      mExportFuncMetadata->addOperand(
519          llvm::MDNode::get(mLLVMContext,
520                            ExportFuncInfo.data(),
521                            ExportFuncInfo.size()));
522
523      ExportFuncInfo.clear();
524    }
525  }
526
527  // Dump export type info
528  if (mContext->hasExportType()) {
529    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
530
531    for (RSContext::const_export_type_iterator
532            I = mContext->export_types_begin(),
533            E = mContext->export_types_end();
534         I != E;
535         I++) {
536      // First, dump type name list to export
537      const RSExportType *ET = I->getValue();
538
539      ExportTypeInfo.clear();
540      // Type name
541      ExportTypeInfo.push_back(
542          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
543
544      if (ET->getClass() == RSExportType::ExportClassRecord) {
545        const RSExportRecordType *ERT =
546            static_cast<const RSExportRecordType*>(ET);
547
548        if (mExportTypeMetadata == NULL)
549          mExportTypeMetadata =
550              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
551
552        mExportTypeMetadata->addOperand(
553            llvm::MDNode::get(mLLVMContext,
554                              ExportTypeInfo.data(),
555                              ExportTypeInfo.size()));
556
557        // Now, export struct field information to %[struct name]
558        std::string StructInfoMetadataName("%");
559        StructInfoMetadataName.append(ET->getName());
560        llvm::NamedMDNode *StructInfoMetadata =
561            M->getOrInsertNamedMetadata(StructInfoMetadataName);
562        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
563
564        assert(StructInfoMetadata->getNumOperands() == 0 &&
565               "Metadata with same name was created before");
566        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
567                FE = ERT->fields_end();
568             FI != FE;
569             FI++) {
570          const RSExportRecordType::Field *F = *FI;
571
572          // 1. field name
573          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
574                                                  F->getName().c_str()));
575
576          // 2. field type name
577          FieldInfo.push_back(
578              llvm::MDString::get(mLLVMContext,
579                                  F->getType()->getName().c_str()));
580
581          // 3. field kind
582          switch (F->getType()->getClass()) {
583            case RSExportType::ExportClassPrimitive:
584            case RSExportType::ExportClassVector: {
585              const RSExportPrimitiveType *EPT =
586                  static_cast<const RSExportPrimitiveType*>(F->getType());
587              FieldInfo.push_back(
588                  llvm::MDString::get(mLLVMContext,
589                                      llvm::itostr(EPT->getKind())));
590              break;
591            }
592
593            default: {
594              FieldInfo.push_back(
595                  llvm::MDString::get(mLLVMContext,
596                                      llvm::itostr(
597                                        RSExportPrimitiveType::DataKindUser)));
598              break;
599            }
600          }
601
602          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
603                                                           FieldInfo.data(),
604                                                           FieldInfo.size()));
605
606          FieldInfo.clear();
607        }
608      }   // ET->getClass() == RSExportType::ExportClassRecord
609    }
610  }
611
612  return;
613}
614
615RSBackend::~RSBackend() {
616  return;
617}
618