slang_rs_backend.cpp revision fcda2352b9e140529f8f3c89f05b10a70c0048b2
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
#include <cassert>
#include <cstdio>
#include <list>
#include <stack>
#include <string>
#include <vector>
22
23#include "llvm/Metadata.h"
24#include "llvm/Constant.h"
25#include "llvm/Constants.h"
26#include "llvm/Module.h"
27#include "llvm/Function.h"
28#include "llvm/DerivedTypes.h"
29
30#include "llvm/Support/IRBuilder.h"
31
32#include "llvm/ADT/Twine.h"
33#include "llvm/ADT/StringExtras.h"
34
35#include "clang/AST/DeclGroup.h"
36#include "clang/AST/Expr.h"
37#include "clang/AST/OperationKinds.h"
38#include "clang/AST/Stmt.h"
39#include "clang/AST/StmtVisitor.h"
40
41#include "slang_rs.h"
42#include "slang_rs_context.h"
43#include "slang_rs_metadata.h"
44#include "slang_rs_export_var.h"
45#include "slang_rs_export_func.h"
46#include "slang_rs_export_type.h"
47
48using namespace slang;
49
50RSBackend::RSBackend(RSContext *Context,
51                     clang::Diagnostic &Diags,
52                     const clang::CodeGenOptions &CodeGenOpts,
53                     const clang::TargetOptions &TargetOpts,
54                     const PragmaList &Pragmas,
55                     llvm::raw_ostream *OS,
56                     Slang::OutputType OT,
57                     clang::SourceManager &SourceMgr,
58                     bool AllowRSPrefix)
59    : Backend(Diags,
60              CodeGenOpts,
61              TargetOpts,
62              Pragmas,
63              OS,
64              OT),
65      mContext(Context),
66      mSourceMgr(SourceMgr),
67      mAllowRSPrefix(AllowRSPrefix),
68      mExportVarMetadata(NULL),
69      mExportFuncMetadata(NULL),
70      mExportTypeMetadata(NULL) {
71  return;
72}
73
74void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
75  // Disallow user-defined functions with prefix "rs"
76  if (!mAllowRSPrefix) {
77    // Iterate all function declarations in the program.
78    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
79         I != E; I++) {
80      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
81      if (FD == NULL)
82        continue;
83      if (!FD->getName().startswith("rs"))  // Check prefix
84        continue;
85      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
86        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
87                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
88                                             "invalid function name prefix, "
89                                             "\"rs\" is reserved: '%0'"))
90            << FD->getName();
91    }
92  }
93
94  Backend::HandleTopLevelDecl(D);
95  return;
96}
97///////////////////////////////////////////////////////////////////////////////
98
99namespace {
100
101  class RSObjectRefCounting : public clang::StmtVisitor<RSObjectRefCounting> {
102   private:
103    class Scope {
104     private:
105      clang::CompoundStmt *mCS;      // Associated compound statement ({ ... })
106      std::list<clang::Decl*> mRSO;  // Declared RS object in this scope
107
108     public:
109      Scope(clang::CompoundStmt *CS) : mCS(CS) {
110        return;
111      }
112
113      inline void addRSObject(clang::Decl* D) { mRSO.push_back(D); }
114    };
115    std::stack<Scope*> mScopeStack;
116
117    inline Scope *getCurrentScope() { return mScopeStack.top(); }
118
119    // Return false if the process was terminated early. I.e., Type of variable
120    // in VD is not an RS object type.
121    static bool InitializeRSObject(clang::VarDecl *VD);
122    static clang::Expr *CreateZeroInitializerForRSObject(
123        RSExportPrimitiveType::DataType DT,
124        clang::ASTContext &C,
125        const clang::SourceLocation &Loc);
126
127   public:
128    void VisitChildren(clang::Stmt *S) {
129      for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
130           I != E;
131           I++)
132        if (clang::Stmt *Child = *I)
133          Visit(Child);
134    }
135    void VisitStmt(clang::Stmt *S) { VisitChildren(S); }
136
137    void VisitDeclStmt(clang::DeclStmt *DS);
138    void VisitCompoundStmt(clang::CompoundStmt *CS);
139    void VisitBinAssign(clang::BinaryOperator *AS);
140
141    // We believe that RS objects never are involved in CompoundAssignOperator.
142    // I.e., rs_allocation foo; foo += bar;
143  };
144}
145
146bool RSObjectRefCounting::InitializeRSObject(clang::VarDecl *VD) {
147  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
148  RSExportPrimitiveType::DataType DT =
149      RSExportPrimitiveType::GetRSObjectType(T);
150
151  if (DT == RSExportPrimitiveType::DataTypeUnknown)
152    return false;
153
154  if (VD->hasInit()) {
155    // TODO: Update the reference count of RS object in initializer.
156    // This can potentially be done as part of the assignment pass.
157  } else {
158    clang::Expr *ZeroInitializer =
159        CreateZeroInitializerForRSObject(DT,
160                                         VD->getASTContext(),
161                                         VD->getLocation());
162
163    if (ZeroInitializer) {
164      ZeroInitializer->setType(T->getCanonicalTypeInternal());
165      VD->setInit(ZeroInitializer);
166    }
167  }
168
169  return true;
170}
171
172clang::Expr *RSObjectRefCounting::CreateZeroInitializerForRSObject(
173    RSExportPrimitiveType::DataType DT,
174    clang::ASTContext &C,
175    const clang::SourceLocation &Loc) {
176  clang::Expr *Res = NULL;
177  switch (DT) {
178    case RSExportPrimitiveType::DataTypeRSElement:
179    case RSExportPrimitiveType::DataTypeRSType:
180    case RSExportPrimitiveType::DataTypeRSAllocation:
181    case RSExportPrimitiveType::DataTypeRSSampler:
182    case RSExportPrimitiveType::DataTypeRSScript:
183    case RSExportPrimitiveType::DataTypeRSMesh:
184    case RSExportPrimitiveType::DataTypeRSProgramFragment:
185    case RSExportPrimitiveType::DataTypeRSProgramVertex:
186    case RSExportPrimitiveType::DataTypeRSProgramRaster:
187    case RSExportPrimitiveType::DataTypeRSProgramStore:
188    case RSExportPrimitiveType::DataTypeRSFont: {
189      //    (ImplicitCastExpr 'nullptr_t'
190      //      (IntegerLiteral 0)))
191      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
192      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
193      clang::Expr *CastToNull =
194          clang::ImplicitCastExpr::Create(C,
195                                          C.NullPtrTy,
196                                          clang::CK_IntegralToPointer,
197                                          Int0,
198                                          NULL,
199                                          clang::VK_RValue);
200
201      Res = new (C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
202      break;
203    }
204    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
205    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
206    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
207      // RS matrix is not completely an RS object. They hold data by themselves.
208      // (InitListExpr rs_matrix2x2
209      //   (InitListExpr float[4]
210      //     (FloatingLiteral 0)
211      //     (FloatingLiteral 0)
212      //     (FloatingLiteral 0)
213      //     (FloatingLiteral 0)))
214      clang::QualType FloatTy = C.FloatTy;
215      // Constructor sets value to 0.0f by default
216      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
217      clang::FloatingLiteral *Float0Val =
218          clang::FloatingLiteral::Create(C,
219                                         Val,
220                                         /* isExact = */true,
221                                         FloatTy,
222                                         Loc);
223
224      unsigned N;
225      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
226        N = 2;
227      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
228        N = 3;
229      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
230        N = 4;
231
232      // Directly allocate 16 elements instead of dynamically allocate N*N
233      clang::Expr *InitVals[16];
234      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
235        InitVals[i] = Float0Val;
236      clang::Expr *InitExpr =
237          new (C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
238      InitExpr->setType(C.getConstantArrayType(FloatTy,
239                                               llvm::APInt(32, 4),
240                                               clang::ArrayType::Normal,
241                                               /* EltTypeQuals = */0));
242
243      Res = new (C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
244      break;
245    }
246    case RSExportPrimitiveType::DataTypeUnknown:
247    case RSExportPrimitiveType::DataTypeFloat16:
248    case RSExportPrimitiveType::DataTypeFloat32:
249    case RSExportPrimitiveType::DataTypeFloat64:
250    case RSExportPrimitiveType::DataTypeSigned8:
251    case RSExportPrimitiveType::DataTypeSigned16:
252    case RSExportPrimitiveType::DataTypeSigned32:
253    case RSExportPrimitiveType::DataTypeSigned64:
254    case RSExportPrimitiveType::DataTypeUnsigned8:
255    case RSExportPrimitiveType::DataTypeUnsigned16:
256    case RSExportPrimitiveType::DataTypeUnsigned32:
257    case RSExportPrimitiveType::DataTypeUnsigned64:
258    case RSExportPrimitiveType::DataTypeBoolean:
259    case RSExportPrimitiveType::DataTypeUnsigned565:
260    case RSExportPrimitiveType::DataTypeUnsigned5551:
261    case RSExportPrimitiveType::DataTypeUnsigned4444:
262    case RSExportPrimitiveType::DataTypeMax: {
263      assert(false && "Not RS object type!");
264    }
265    // No default case will enable compiler detecting the missing cases
266  }
267
268  return Res;
269}
270
271void RSObjectRefCounting::VisitDeclStmt(clang::DeclStmt *DS) {
272  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
273       I != E;
274       I++) {
275    clang::Decl *D = *I;
276    if (D->getKind() == clang::Decl::Var) {
277      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
278      if (InitializeRSObject(VD))
279        getCurrentScope()->addRSObject(VD);
280    }
281  }
282  return;
283}
284
285void RSObjectRefCounting::VisitCompoundStmt(clang::CompoundStmt *CS) {
286  if (!CS->body_empty()) {
287    // Push a new scope
288    Scope *S = new Scope(CS);
289    mScopeStack.push(S);
290
291    VisitChildren(CS);
292
293    // Destroy the scope
294    // TODO: Update reference count of the RS object refenced by the
295    //       getCurrentScope().
296    assert((getCurrentScope() == S) && "Corrupted scope stack!");
297    mScopeStack.pop();
298    delete S;
299  }
300  return;
301}
302
303void RSObjectRefCounting::VisitBinAssign(clang::BinaryOperator *AS) {
304  // TODO: Update reference count
305  return;
306}
307
308void RSBackend::HandleTranslationUnitPre(clang::ASTContext& C) {
309  RSObjectRefCounting RSObjectRefCounter;
310  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
311
312  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
313          E = TUDecl->decls_end(); I != E; I++) {
314    if ((I->getKind() >= clang::Decl::firstFunction) &&
315        (I->getKind() <= clang::Decl::lastFunction)) {
316      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
317      if (FD->hasBody() && !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
318        RSObjectRefCounter.Visit( FD->getBody());
319    }
320  }
321
322  return;
323}
324
325///////////////////////////////////////////////////////////////////////////////
326void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
327  mContext->processExport();
328
329  // Dump export variable info
330  if (mContext->hasExportVar()) {
331    if (mExportVarMetadata == NULL)
332      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
333
334    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
335
336    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
337            E = mContext->export_vars_end();
338         I != E;
339         I++) {
340      const RSExportVar *EV = *I;
341      const RSExportType *ET = EV->getType();
342
343      // Variable name
344      ExportVarInfo.push_back(
345          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
346
347      // Type name
348      switch (ET->getClass()) {
349        case RSExportType::ExportClassPrimitive: {
350          ExportVarInfo.push_back(
351              llvm::MDString::get(
352                mLLVMContext, llvm::utostr_32(
353                  static_cast<const RSExportPrimitiveType*>(ET)->getType())));
354          break;
355        }
356        case RSExportType::ExportClassPointer: {
357          ExportVarInfo.push_back(
358              llvm::MDString::get(
359                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
360                  ->getPointeeType()->getName()).c_str()));
361          break;
362        }
363        case RSExportType::ExportClassMatrix: {
364          ExportVarInfo.push_back(
365              llvm::MDString::get(
366                mLLVMContext, llvm::utostr_32(
367                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
368                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
369          break;
370        }
371        case RSExportType::ExportClassVector:
372        case RSExportType::ExportClassConstantArray:
373        case RSExportType::ExportClassRecord: {
374          ExportVarInfo.push_back(
375              llvm::MDString::get(mLLVMContext,
376                EV->getType()->getName().c_str()));
377          break;
378        }
379      }
380
381      mExportVarMetadata->addOperand(
382          llvm::MDNode::get(mLLVMContext,
383                            ExportVarInfo.data(),
384                            ExportVarInfo.size()) );
385
386      ExportVarInfo.clear();
387    }
388  }
389
390  // Dump export function info
391  if (mContext->hasExportFunc()) {
392    if (mExportFuncMetadata == NULL)
393      mExportFuncMetadata =
394          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
395
396    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
397
398    for (RSContext::const_export_func_iterator
399            I = mContext->export_funcs_begin(),
400            E = mContext->export_funcs_end();
401         I != E;
402         I++) {
403      const RSExportFunc *EF = *I;
404
405      // Function name
406      if (!EF->hasParam()) {
407        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
408                                                     EF->getName().c_str()));
409      } else {
410        llvm::Function *F = M->getFunction(EF->getName());
411        llvm::Function *HelperFunction;
412        const std::string HelperFunctionName(".helper_" + EF->getName());
413
414        assert(F && "Function marked as exported disappeared in Bitcode");
415
416        // Create helper function
417        {
418          llvm::StructType *HelperFunctionParameterTy = NULL;
419
420          if (!F->getArgumentList().empty()) {
421            std::vector<const llvm::Type*> HelperFunctionParameterTys;
422            for (llvm::Function::arg_iterator AI = F->arg_begin(),
423                 AE = F->arg_end(); AI != AE; AI++)
424              HelperFunctionParameterTys.push_back(AI->getType());
425
426            HelperFunctionParameterTy =
427                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
428          }
429
430          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
431            fprintf(stderr, "Failed to export function %s: parameter type "
432                            "mismatch during creation of helper function.\n",
433                    EF->getName().c_str());
434
435            const RSExportRecordType *Expected = EF->getParamPacketType();
436            if (Expected) {
437              fprintf(stderr, "Expected:\n");
438              Expected->getLLVMType()->dump();
439            }
440            if (HelperFunctionParameterTy) {
441              fprintf(stderr, "Got:\n");
442              HelperFunctionParameterTy->dump();
443            }
444          }
445
446          std::vector<const llvm::Type*> Params;
447          if (HelperFunctionParameterTy) {
448            llvm::PointerType *HelperFunctionParameterTyP =
449                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
450            Params.push_back(HelperFunctionParameterTyP);
451          }
452
453          llvm::FunctionType * HelperFunctionType =
454              llvm::FunctionType::get(F->getReturnType(),
455                                      Params,
456                                      /* IsVarArgs = */false);
457
458          HelperFunction =
459              llvm::Function::Create(HelperFunctionType,
460                                     llvm::GlobalValue::ExternalLinkage,
461                                     HelperFunctionName,
462                                     M);
463
464          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
465          HelperFunction->setCallingConv(F->getCallingConv());
466
467          // Create helper function body
468          {
469            llvm::Argument *HelperFunctionParameter =
470                &(*HelperFunction->arg_begin());
471            llvm::BasicBlock *BB =
472                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
473            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
474            llvm::SmallVector<llvm::Value*, 6> Params;
475            llvm::Value *Idx[2];
476
477            Idx[0] =
478                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
479
480            // getelementptr and load instruction for all elements in
481            // parameter .p
482            for (size_t i = 0; i < EF->getNumParameters(); i++) {
483              // getelementptr
484              Idx[1] =
485                  llvm::ConstantInt::get(
486                      llvm::Type::getInt32Ty(mLLVMContext), i);
487              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
488                                                       Idx,
489                                                       Idx + 2);
490
491              // load
492              llvm::Value *V = IB->CreateLoad(Ptr);
493              Params.push_back(V);
494            }
495
496            // Call and pass the all elements as paramter to F
497            llvm::CallInst *CI = IB->CreateCall(F,
498                                                Params.data(),
499                                                Params.data() + Params.size());
500
501            CI->setCallingConv(F->getCallingConv());
502
503            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
504              IB->CreateRetVoid();
505            else
506              IB->CreateRet(CI);
507
508            delete IB;
509          }
510        }
511
512        ExportFuncInfo.push_back(
513            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
514      }
515
516      mExportFuncMetadata->addOperand(
517          llvm::MDNode::get(mLLVMContext,
518                            ExportFuncInfo.data(),
519                            ExportFuncInfo.size()));
520
521      ExportFuncInfo.clear();
522    }
523  }
524
525  // Dump export type info
526  if (mContext->hasExportType()) {
527    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
528
529    for (RSContext::const_export_type_iterator
530            I = mContext->export_types_begin(),
531            E = mContext->export_types_end();
532         I != E;
533         I++) {
534      // First, dump type name list to export
535      const RSExportType *ET = I->getValue();
536
537      ExportTypeInfo.clear();
538      // Type name
539      ExportTypeInfo.push_back(
540          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
541
542      if (ET->getClass() == RSExportType::ExportClassRecord) {
543        const RSExportRecordType *ERT =
544            static_cast<const RSExportRecordType*>(ET);
545
546        if (mExportTypeMetadata == NULL)
547          mExportTypeMetadata =
548              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
549
550        mExportTypeMetadata->addOperand(
551            llvm::MDNode::get(mLLVMContext,
552                              ExportTypeInfo.data(),
553                              ExportTypeInfo.size()));
554
555        // Now, export struct field information to %[struct name]
556        std::string StructInfoMetadataName("%");
557        StructInfoMetadataName.append(ET->getName());
558        llvm::NamedMDNode *StructInfoMetadata =
559            M->getOrInsertNamedMetadata(StructInfoMetadataName);
560        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
561
562        assert(StructInfoMetadata->getNumOperands() == 0 &&
563               "Metadata with same name was created before");
564        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
565                FE = ERT->fields_end();
566             FI != FE;
567             FI++) {
568          const RSExportRecordType::Field *F = *FI;
569
570          // 1. field name
571          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
572                                                  F->getName().c_str()));
573
574          // 2. field type name
575          FieldInfo.push_back(
576              llvm::MDString::get(mLLVMContext,
577                                  F->getType()->getName().c_str()));
578
579          // 3. field kind
580          switch (F->getType()->getClass()) {
581            case RSExportType::ExportClassPrimitive:
582            case RSExportType::ExportClassVector: {
583              const RSExportPrimitiveType *EPT =
584                  static_cast<const RSExportPrimitiveType*>(F->getType());
585              FieldInfo.push_back(
586                  llvm::MDString::get(mLLVMContext,
587                                      llvm::itostr(EPT->getKind())));
588              break;
589            }
590
591            default: {
592              FieldInfo.push_back(
593                  llvm::MDString::get(mLLVMContext,
594                                      llvm::itostr(
595                                        RSExportPrimitiveType::DataKindUser)));
596              break;
597            }
598          }
599
600          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
601                                                           FieldInfo.data(),
602                                                           FieldInfo.size()));
603
604          FieldInfo.clear();
605        }
606      }   // ET->getClass() == RSExportType::ExportClassRecord
607    }
608  }
609
610  return;
611}
612
613RSBackend::~RSBackend() {
614  return;
615}
616