slang_rs_backend.cpp revision cfae0f350e4e0d8f7b4c71780b3f74f58fa23afb
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
#include <cassert>
#include <cstdio>
#include <list>
#include <stack>
#include <string>
#include <vector>
22
23#include "llvm/Metadata.h"
24#include "llvm/Constant.h"
25#include "llvm/Constants.h"
26#include "llvm/Module.h"
27#include "llvm/Function.h"
28#include "llvm/DerivedTypes.h"
29
30#include "llvm/Support/IRBuilder.h"
31
32#include "llvm/ADT/Twine.h"
33#include "llvm/ADT/StringExtras.h"
34
35#include "clang/AST/DeclGroup.h"
36#include "clang/AST/Expr.h"
37#include "clang/AST/OperationKinds.h"
38#include "clang/AST/Stmt.h"
39#include "clang/AST/StmtVisitor.h"
40
41#include "slang_rs.h"
42#include "slang_rs_context.h"
43#include "slang_rs_metadata.h"
44#include "slang_rs_export_var.h"
45#include "slang_rs_export_func.h"
46#include "slang_rs_export_type.h"
47
48using namespace slang;
49
50RSBackend::RSBackend(RSContext *Context,
51                     clang::Diagnostic &Diags,
52                     const clang::CodeGenOptions &CodeGenOpts,
53                     const clang::TargetOptions &TargetOpts,
54                     const PragmaList &Pragmas,
55                     llvm::raw_ostream *OS,
56                     Slang::OutputType OT,
57                     clang::SourceManager &SourceMgr,
58                     bool AllowRSPrefix)
59    : Backend(Diags,
60              CodeGenOpts,
61              TargetOpts,
62              Pragmas,
63              OS,
64              OT),
65      mContext(Context),
66      mSourceMgr(SourceMgr),
67      mAllowRSPrefix(AllowRSPrefix),
68      mExportVarMetadata(NULL),
69      mExportFuncMetadata(NULL),
70      mExportTypeMetadata(NULL) {
71  return;
72}
73
74///////////////////////////////////////////////////////////////////////////////
75
76namespace {
77
78  class RSObjectRefCounting : public clang::StmtVisitor<RSObjectRefCounting> {
79   private:
80    class Scope {
81     private:
82      clang::CompoundStmt *mCS;      // Associated compound statement ({ ... })
83      std::list<clang::Decl*> mRSO;  // Declared RS object in this scope
84
85     public:
86      Scope(clang::CompoundStmt *CS) : mCS(CS) {
87        return;
88      }
89
90      inline void addRSObject(clang::Decl* D) { mRSO.push_back(D); }
91    };
92    std::stack<Scope*> mScopeStack;
93
94    inline Scope *getCurrentScope() { return mScopeStack.top(); }
95
96    // Return false if the type of variable declared in VD is not an RS object
97    // type.
98    static bool InitializeRSObject(clang::VarDecl *VD);
99    // Return an zero-initializer expr of the type DT. This processes both
100    // RS matrix type and RS object type.
101    static clang::Expr *CreateZeroInitializerForRSSpecificType(
102        RSExportPrimitiveType::DataType DT,
103        clang::ASTContext &C,
104        const clang::SourceLocation &Loc);
105
106   public:
107    void VisitChildren(clang::Stmt *S) {
108      for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
109           I != E;
110           I++)
111        if (clang::Stmt *Child = *I)
112          Visit(Child);
113    }
114    void VisitStmt(clang::Stmt *S) { VisitChildren(S); }
115
116    void VisitDeclStmt(clang::DeclStmt *DS);
117    void VisitCompoundStmt(clang::CompoundStmt *CS);
118    void VisitBinAssign(clang::BinaryOperator *AS);
119
120    // We believe that RS objects never are involved in CompoundAssignOperator.
121    // I.e., rs_allocation foo; foo += bar;
122  };
123}
124
125bool RSObjectRefCounting::InitializeRSObject(clang::VarDecl *VD) {
126  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
127  RSExportPrimitiveType::DataType DT =
128      RSExportPrimitiveType::GetRSSpecificType(T);
129
130  if (DT == RSExportPrimitiveType::DataTypeUnknown)
131    return false;
132
133  if (VD->hasInit()) {
134    // TODO: Update the reference count of RS object in initializer.
135    // This can potentially be done as part of the assignment pass.
136  } else {
137    clang::Expr *ZeroInitializer =
138        CreateZeroInitializerForRSSpecificType(DT,
139                                               VD->getASTContext(),
140                                               VD->getLocation());
141
142    if (ZeroInitializer) {
143      ZeroInitializer->setType(T->getCanonicalTypeInternal());
144      VD->setInit(ZeroInitializer);
145    }
146  }
147
148  return RSExportPrimitiveType::IsRSObjectType(DT);
149}
150
151clang::Expr *RSObjectRefCounting::CreateZeroInitializerForRSSpecificType(
152    RSExportPrimitiveType::DataType DT,
153    clang::ASTContext &C,
154    const clang::SourceLocation &Loc) {
155  clang::Expr *Res = NULL;
156  switch (DT) {
157    case RSExportPrimitiveType::DataTypeRSElement:
158    case RSExportPrimitiveType::DataTypeRSType:
159    case RSExportPrimitiveType::DataTypeRSAllocation:
160    case RSExportPrimitiveType::DataTypeRSSampler:
161    case RSExportPrimitiveType::DataTypeRSScript:
162    case RSExportPrimitiveType::DataTypeRSMesh:
163    case RSExportPrimitiveType::DataTypeRSProgramFragment:
164    case RSExportPrimitiveType::DataTypeRSProgramVertex:
165    case RSExportPrimitiveType::DataTypeRSProgramRaster:
166    case RSExportPrimitiveType::DataTypeRSProgramStore:
167    case RSExportPrimitiveType::DataTypeRSFont: {
168      //    (ImplicitCastExpr 'nullptr_t'
169      //      (IntegerLiteral 0)))
170      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
171      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
172      clang::Expr *CastToNull =
173          clang::ImplicitCastExpr::Create(C,
174                                          C.NullPtrTy,
175                                          clang::CK_IntegralToPointer,
176                                          Int0,
177                                          NULL,
178                                          clang::VK_RValue);
179
180      Res = new (C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
181      break;
182    }
183    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
184    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
185    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
186      // RS matrix is not completely an RS object. They hold data by themselves.
187      // (InitListExpr rs_matrix2x2
188      //   (InitListExpr float[4]
189      //     (FloatingLiteral 0)
190      //     (FloatingLiteral 0)
191      //     (FloatingLiteral 0)
192      //     (FloatingLiteral 0)))
193      clang::QualType FloatTy = C.FloatTy;
194      // Constructor sets value to 0.0f by default
195      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
196      clang::FloatingLiteral *Float0Val =
197          clang::FloatingLiteral::Create(C,
198                                         Val,
199                                         /* isExact = */true,
200                                         FloatTy,
201                                         Loc);
202
203      unsigned N = 0;
204      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
205        N = 2;
206      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
207        N = 3;
208      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
209        N = 4;
210
211      // Directly allocate 16 elements instead of dynamically allocate N*N
212      clang::Expr *InitVals[16];
213      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
214        InitVals[i] = Float0Val;
215      clang::Expr *InitExpr =
216          new (C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
217      InitExpr->setType(C.getConstantArrayType(FloatTy,
218                                               llvm::APInt(32, 4),
219                                               clang::ArrayType::Normal,
220                                               /* EltTypeQuals = */0));
221
222      Res = new (C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
223      break;
224    }
225    case RSExportPrimitiveType::DataTypeUnknown:
226    case RSExportPrimitiveType::DataTypeFloat16:
227    case RSExportPrimitiveType::DataTypeFloat32:
228    case RSExportPrimitiveType::DataTypeFloat64:
229    case RSExportPrimitiveType::DataTypeSigned8:
230    case RSExportPrimitiveType::DataTypeSigned16:
231    case RSExportPrimitiveType::DataTypeSigned32:
232    case RSExportPrimitiveType::DataTypeSigned64:
233    case RSExportPrimitiveType::DataTypeUnsigned8:
234    case RSExportPrimitiveType::DataTypeUnsigned16:
235    case RSExportPrimitiveType::DataTypeUnsigned32:
236    case RSExportPrimitiveType::DataTypeUnsigned64:
237    case RSExportPrimitiveType::DataTypeBoolean:
238    case RSExportPrimitiveType::DataTypeUnsigned565:
239    case RSExportPrimitiveType::DataTypeUnsigned5551:
240    case RSExportPrimitiveType::DataTypeUnsigned4444:
241    case RSExportPrimitiveType::DataTypeMax: {
242      assert(false && "Not RS object type!");
243    }
244    // No default case will enable compiler detecting the missing cases
245  }
246
247  return Res;
248}
249
250void RSObjectRefCounting::VisitDeclStmt(clang::DeclStmt *DS) {
251  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
252       I != E;
253       I++) {
254    clang::Decl *D = *I;
255    if (D->getKind() == clang::Decl::Var) {
256      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
257      if (InitializeRSObject(VD))
258        getCurrentScope()->addRSObject(VD);
259    }
260  }
261  return;
262}
263
264void RSObjectRefCounting::VisitCompoundStmt(clang::CompoundStmt *CS) {
265  if (!CS->body_empty()) {
266    // Push a new scope
267    Scope *S = new Scope(CS);
268    mScopeStack.push(S);
269
270    VisitChildren(CS);
271
272    // Destroy the scope
273    // TODO: Update reference count of the RS object refenced by the
274    //       getCurrentScope().
275    assert((getCurrentScope() == S) && "Corrupted scope stack!");
276    mScopeStack.pop();
277    delete S;
278  }
279  return;
280}
281
282void RSObjectRefCounting::VisitBinAssign(clang::BinaryOperator *AS) {
283  // TODO: Update reference count
284  return;
285}
286
287// 1) Add zero initialization of local RS object types
288void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
289  if (FD &&
290      FD->hasBody() &&
291      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
292    RSObjectRefCounting RSObjectRefCounter;
293    RSObjectRefCounter.Visit(FD->getBody());
294  }
295  return;
296}
297
298void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
299  // Disallow user-defined functions with prefix "rs"
300  if (!mAllowRSPrefix) {
301    // Iterate all function declarations in the program.
302    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
303         I != E; I++) {
304      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
305      if (FD == NULL)
306        continue;
307      if (!FD->getName().startswith("rs"))  // Check prefix
308        continue;
309      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
310        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
311                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
312                                             "invalid function name prefix, "
313                                             "\"rs\" is reserved: '%0'"))
314            << FD->getName();
315    }
316  }
317
318  // Process any non-static function declarations
319  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
320    AnnotateFunction(dyn_cast<clang::FunctionDecl>(*I));
321  }
322
323  Backend::HandleTopLevelDecl(D);
324  return;
325}
326
327void RSBackend::HandleTranslationUnitPre(clang::ASTContext& C) {
328  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
329
330  // Process any static function declarations
331  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
332          E = TUDecl->decls_end(); I != E; I++) {
333    if ((I->getKind() >= clang::Decl::firstFunction) &&
334        (I->getKind() <= clang::Decl::lastFunction)) {
335      AnnotateFunction(static_cast<clang::FunctionDecl*>(*I));
336    }
337  }
338
339  return;
340}
341
342///////////////////////////////////////////////////////////////////////////////
343void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
344  mContext->processExport();
345
346  // Dump export variable info
347  if (mContext->hasExportVar()) {
348    if (mExportVarMetadata == NULL)
349      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
350
351    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
352
353    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
354            E = mContext->export_vars_end();
355         I != E;
356         I++) {
357      const RSExportVar *EV = *I;
358      const RSExportType *ET = EV->getType();
359
360      // Variable name
361      ExportVarInfo.push_back(
362          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
363
364      // Type name
365      switch (ET->getClass()) {
366        case RSExportType::ExportClassPrimitive: {
367          ExportVarInfo.push_back(
368              llvm::MDString::get(
369                mLLVMContext, llvm::utostr_32(
370                  static_cast<const RSExportPrimitiveType*>(ET)->getType())));
371          break;
372        }
373        case RSExportType::ExportClassPointer: {
374          ExportVarInfo.push_back(
375              llvm::MDString::get(
376                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
377                  ->getPointeeType()->getName()).c_str()));
378          break;
379        }
380        case RSExportType::ExportClassMatrix: {
381          ExportVarInfo.push_back(
382              llvm::MDString::get(
383                mLLVMContext, llvm::utostr_32(
384                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
385                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
386          break;
387        }
388        case RSExportType::ExportClassVector:
389        case RSExportType::ExportClassConstantArray:
390        case RSExportType::ExportClassRecord: {
391          ExportVarInfo.push_back(
392              llvm::MDString::get(mLLVMContext,
393                EV->getType()->getName().c_str()));
394          break;
395        }
396      }
397
398      mExportVarMetadata->addOperand(
399          llvm::MDNode::get(mLLVMContext,
400                            ExportVarInfo.data(),
401                            ExportVarInfo.size()) );
402
403      ExportVarInfo.clear();
404    }
405  }
406
407  // Dump export function info
408  if (mContext->hasExportFunc()) {
409    if (mExportFuncMetadata == NULL)
410      mExportFuncMetadata =
411          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
412
413    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
414
415    for (RSContext::const_export_func_iterator
416            I = mContext->export_funcs_begin(),
417            E = mContext->export_funcs_end();
418         I != E;
419         I++) {
420      const RSExportFunc *EF = *I;
421
422      // Function name
423      if (!EF->hasParam()) {
424        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
425                                                     EF->getName().c_str()));
426      } else {
427        llvm::Function *F = M->getFunction(EF->getName());
428        llvm::Function *HelperFunction;
429        const std::string HelperFunctionName(".helper_" + EF->getName());
430
431        assert(F && "Function marked as exported disappeared in Bitcode");
432
433        // Create helper function
434        {
435          llvm::StructType *HelperFunctionParameterTy = NULL;
436
437          if (!F->getArgumentList().empty()) {
438            std::vector<const llvm::Type*> HelperFunctionParameterTys;
439            for (llvm::Function::arg_iterator AI = F->arg_begin(),
440                 AE = F->arg_end(); AI != AE; AI++)
441              HelperFunctionParameterTys.push_back(AI->getType());
442
443            HelperFunctionParameterTy =
444                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
445          }
446
447          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
448            fprintf(stderr, "Failed to export function %s: parameter type "
449                            "mismatch during creation of helper function.\n",
450                    EF->getName().c_str());
451
452            const RSExportRecordType *Expected = EF->getParamPacketType();
453            if (Expected) {
454              fprintf(stderr, "Expected:\n");
455              Expected->getLLVMType()->dump();
456            }
457            if (HelperFunctionParameterTy) {
458              fprintf(stderr, "Got:\n");
459              HelperFunctionParameterTy->dump();
460            }
461          }
462
463          std::vector<const llvm::Type*> Params;
464          if (HelperFunctionParameterTy) {
465            llvm::PointerType *HelperFunctionParameterTyP =
466                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
467            Params.push_back(HelperFunctionParameterTyP);
468          }
469
470          llvm::FunctionType * HelperFunctionType =
471              llvm::FunctionType::get(F->getReturnType(),
472                                      Params,
473                                      /* IsVarArgs = */false);
474
475          HelperFunction =
476              llvm::Function::Create(HelperFunctionType,
477                                     llvm::GlobalValue::ExternalLinkage,
478                                     HelperFunctionName,
479                                     M);
480
481          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
482          HelperFunction->setCallingConv(F->getCallingConv());
483
484          // Create helper function body
485          {
486            llvm::Argument *HelperFunctionParameter =
487                &(*HelperFunction->arg_begin());
488            llvm::BasicBlock *BB =
489                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
490            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
491            llvm::SmallVector<llvm::Value*, 6> Params;
492            llvm::Value *Idx[2];
493
494            Idx[0] =
495                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
496
497            // getelementptr and load instruction for all elements in
498            // parameter .p
499            for (size_t i = 0; i < EF->getNumParameters(); i++) {
500              // getelementptr
501              Idx[1] =
502                  llvm::ConstantInt::get(
503                      llvm::Type::getInt32Ty(mLLVMContext), i);
504              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
505                                                       Idx,
506                                                       Idx + 2);
507
508              // load
509              llvm::Value *V = IB->CreateLoad(Ptr);
510              Params.push_back(V);
511            }
512
513            // Call and pass the all elements as paramter to F
514            llvm::CallInst *CI = IB->CreateCall(F,
515                                                Params.data(),
516                                                Params.data() + Params.size());
517
518            CI->setCallingConv(F->getCallingConv());
519
520            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
521              IB->CreateRetVoid();
522            else
523              IB->CreateRet(CI);
524
525            delete IB;
526          }
527        }
528
529        ExportFuncInfo.push_back(
530            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
531      }
532
533      mExportFuncMetadata->addOperand(
534          llvm::MDNode::get(mLLVMContext,
535                            ExportFuncInfo.data(),
536                            ExportFuncInfo.size()));
537
538      ExportFuncInfo.clear();
539    }
540  }
541
542  // Dump export type info
543  if (mContext->hasExportType()) {
544    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
545
546    for (RSContext::const_export_type_iterator
547            I = mContext->export_types_begin(),
548            E = mContext->export_types_end();
549         I != E;
550         I++) {
551      // First, dump type name list to export
552      const RSExportType *ET = I->getValue();
553
554      ExportTypeInfo.clear();
555      // Type name
556      ExportTypeInfo.push_back(
557          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
558
559      if (ET->getClass() == RSExportType::ExportClassRecord) {
560        const RSExportRecordType *ERT =
561            static_cast<const RSExportRecordType*>(ET);
562
563        if (mExportTypeMetadata == NULL)
564          mExportTypeMetadata =
565              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
566
567        mExportTypeMetadata->addOperand(
568            llvm::MDNode::get(mLLVMContext,
569                              ExportTypeInfo.data(),
570                              ExportTypeInfo.size()));
571
572        // Now, export struct field information to %[struct name]
573        std::string StructInfoMetadataName("%");
574        StructInfoMetadataName.append(ET->getName());
575        llvm::NamedMDNode *StructInfoMetadata =
576            M->getOrInsertNamedMetadata(StructInfoMetadataName);
577        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
578
579        assert(StructInfoMetadata->getNumOperands() == 0 &&
580               "Metadata with same name was created before");
581        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
582                FE = ERT->fields_end();
583             FI != FE;
584             FI++) {
585          const RSExportRecordType::Field *F = *FI;
586
587          // 1. field name
588          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
589                                                  F->getName().c_str()));
590
591          // 2. field type name
592          FieldInfo.push_back(
593              llvm::MDString::get(mLLVMContext,
594                                  F->getType()->getName().c_str()));
595
596          // 3. field kind
597          switch (F->getType()->getClass()) {
598            case RSExportType::ExportClassPrimitive:
599            case RSExportType::ExportClassVector: {
600              const RSExportPrimitiveType *EPT =
601                  static_cast<const RSExportPrimitiveType*>(F->getType());
602              FieldInfo.push_back(
603                  llvm::MDString::get(mLLVMContext,
604                                      llvm::itostr(EPT->getKind())));
605              break;
606            }
607
608            default: {
609              FieldInfo.push_back(
610                  llvm::MDString::get(mLLVMContext,
611                                      llvm::itostr(
612                                        RSExportPrimitiveType::DataKindUser)));
613              break;
614            }
615          }
616
617          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
618                                                           FieldInfo.data(),
619                                                           FieldInfo.size()));
620
621          FieldInfo.clear();
622        }
623      }   // ET->getClass() == RSExportType::ExportClassRecord
624    }
625  }
626
627  return;
628}
629
630RSBackend::~RSBackend() {
631  return;
632}
633