slang_rs_backend.cpp revision b1771ef128b10c4d4575634828006bfba20b1d9c
1/* 2 * Copyright 2010, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "slang_rs_backend.h" 18 19#include <stack> 20#include <vector> 21#include <string> 22 23#include "llvm/Metadata.h" 24#include "llvm/Constant.h" 25#include "llvm/Constants.h" 26#include "llvm/Module.h" 27#include "llvm/Function.h" 28#include "llvm/DerivedTypes.h" 29 30#include "llvm/Support/IRBuilder.h" 31 32#include "llvm/ADT/Twine.h" 33#include "llvm/ADT/StringExtras.h" 34 35#include "clang/AST/DeclGroup.h" 36#include "clang/AST/Expr.h" 37#include "clang/AST/OperationKinds.h" 38#include "clang/AST/Stmt.h" 39#include "clang/AST/StmtVisitor.h" 40 41#include "slang_rs.h" 42#include "slang_rs_context.h" 43#include "slang_rs_metadata.h" 44#include "slang_rs_export_var.h" 45#include "slang_rs_export_func.h" 46#include "slang_rs_export_type.h" 47 48using namespace slang; 49 50RSBackend::RSBackend(RSContext *Context, 51 clang::Diagnostic &Diags, 52 const clang::CodeGenOptions &CodeGenOpts, 53 const clang::TargetOptions &TargetOpts, 54 const PragmaList &Pragmas, 55 llvm::raw_ostream *OS, 56 Slang::OutputType OT, 57 clang::SourceManager &SourceMgr, 58 bool AllowRSPrefix) 59 : Backend(Diags, 60 CodeGenOpts, 61 TargetOpts, 62 Pragmas, 63 OS, 64 OT), 65 mContext(Context), 66 mSourceMgr(SourceMgr), 67 mAllowRSPrefix(AllowRSPrefix), 68 mExportVarMetadata(NULL), 69 mExportFuncMetadata(NULL), 70 mExportTypeMetadata(NULL) { 71 return; 72} 73 74void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) { 75 // Disallow user-defined functions with prefix "rs" 76 if (!mAllowRSPrefix) { 77 // Iterate all function declarations in the program. 78 for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); 79 I != E; I++) { 80 clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I); 81 if (FD == NULL) 82 continue; 83 if (!FD->getName().startswith("rs")) // Check prefix 84 continue; 85 if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) 86 mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr), 87 mDiags.getCustomDiagID(clang::Diagnostic::Error, 88 "invalid function name prefix, " 89 "\"rs\" is reserved: '%0'")) 90 << FD->getName(); 91 } 92 } 93 94 Backend::HandleTopLevelDecl(D); 95 return; 96} 97/////////////////////////////////////////////////////////////////////////////// 98 99namespace { 100 101 class RSObjectRefCounting : public clang::StmtVisitor<RSObjectRefCounting> { 102 private: 103 class Scope { 104 private: 105 clang::CompoundStmt *mCS; // Associated compound statement ({ ... }) 106 std::list<clang::Decl*> mRSO; // Declared RS object in this scope 107 108 public: 109 Scope(clang::CompoundStmt *CS) : mCS(CS) { 110 return; 111 } 112 113 inline void addRSObject(clang::Decl* D) { mRSO.push_back(D); } 114 }; 115 std::stack<Scope*> mScopeStack; 116 117 inline Scope *getCurrentScope() { return mScopeStack.top(); } 118 119 // Return false if the type of variable declared in VD is not an RS object 120 // type. 121 static bool InitializeRSObject(clang::VarDecl *VD); 122 // Return an zero-initializer expr of the type DT. This processes both 123 // RS matrix type and RS object type. 124 static clang::Expr *CreateZeroInitializerForRSSpecificType( 125 RSExportPrimitiveType::DataType DT, 126 clang::ASTContext &C, 127 const clang::SourceLocation &Loc); 128 129 public: 130 void VisitChildren(clang::Stmt *S) { 131 for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end(); 132 I != E; 133 I++) 134 if (clang::Stmt *Child = *I) 135 Visit(Child); 136 } 137 void VisitStmt(clang::Stmt *S) { VisitChildren(S); } 138 139 void VisitDeclStmt(clang::DeclStmt *DS); 140 void VisitCompoundStmt(clang::CompoundStmt *CS); 141 void VisitBinAssign(clang::BinaryOperator *AS); 142 143 // We believe that RS objects never are involved in CompoundAssignOperator. 144 // I.e., rs_allocation foo; foo += bar; 145 }; 146} 147 148bool RSObjectRefCounting::InitializeRSObject(clang::VarDecl *VD) { 149 const clang::Type *T = RSExportType::GetTypeOfDecl(VD); 150 RSExportPrimitiveType::DataType DT = 151 RSExportPrimitiveType::GetRSSpecificType(T); 152 153 if (DT == RSExportPrimitiveType::DataTypeUnknown) 154 return false; 155 156 if (VD->hasInit()) { 157 // TODO: Update the reference count of RS object in initializer. 158 // This can potentially be done as part of the assignment pass. 159 } else { 160 clang::Expr *ZeroInitializer = 161 CreateZeroInitializerForRSSpecificType(DT, 162 VD->getASTContext(), 163 VD->getLocation()); 164 165 if (ZeroInitializer) { 166 ZeroInitializer->setType(T->getCanonicalTypeInternal()); 167 VD->setInit(ZeroInitializer); 168 } 169 } 170 171 return RSExportPrimitiveType::IsRSObjectType(DT); 172} 173 174clang::Expr *RSObjectRefCounting::CreateZeroInitializerForRSSpecificType( 175 RSExportPrimitiveType::DataType DT, 176 clang::ASTContext &C, 177 const clang::SourceLocation &Loc) { 178 clang::Expr *Res = NULL; 179 switch (DT) { 180 case RSExportPrimitiveType::DataTypeRSElement: 181 case RSExportPrimitiveType::DataTypeRSType: 182 case RSExportPrimitiveType::DataTypeRSAllocation: 183 case RSExportPrimitiveType::DataTypeRSSampler: 184 case RSExportPrimitiveType::DataTypeRSScript: 185 case RSExportPrimitiveType::DataTypeRSMesh: 186 case RSExportPrimitiveType::DataTypeRSProgramFragment: 187 case RSExportPrimitiveType::DataTypeRSProgramVertex: 188 case RSExportPrimitiveType::DataTypeRSProgramRaster: 189 case RSExportPrimitiveType::DataTypeRSProgramStore: 190 case RSExportPrimitiveType::DataTypeRSFont: { 191 // (ImplicitCastExpr 'nullptr_t' 192 // (IntegerLiteral 0))) 193 llvm::APInt Zero(C.getTypeSize(C.IntTy), 0); 194 clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc); 195 clang::Expr *CastToNull = 196 clang::ImplicitCastExpr::Create(C, 197 C.NullPtrTy, 198 clang::CK_IntegralToPointer, 199 Int0, 200 NULL, 201 clang::VK_RValue); 202 203 Res = new (C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc); 204 break; 205 } 206 case RSExportPrimitiveType::DataTypeRSMatrix2x2: 207 case RSExportPrimitiveType::DataTypeRSMatrix3x3: 208 case RSExportPrimitiveType::DataTypeRSMatrix4x4: { 209 // RS matrix is not completely an RS object. They hold data by themselves. 210 // (InitListExpr rs_matrix2x2 211 // (InitListExpr float[4] 212 // (FloatingLiteral 0) 213 // (FloatingLiteral 0) 214 // (FloatingLiteral 0) 215 // (FloatingLiteral 0))) 216 clang::QualType FloatTy = C.FloatTy; 217 // Constructor sets value to 0.0f by default 218 llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy)); 219 clang::FloatingLiteral *Float0Val = 220 clang::FloatingLiteral::Create(C, 221 Val, 222 /* isExact = */true, 223 FloatTy, 224 Loc); 225 226 unsigned N = 0; 227 if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2) 228 N = 2; 229 else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3) 230 N = 3; 231 else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4) 232 N = 4; 233 234 // Directly allocate 16 elements instead of dynamically allocate N*N 235 clang::Expr *InitVals[16]; 236 for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++) 237 InitVals[i] = Float0Val; 238 clang::Expr *InitExpr = 239 new (C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc); 240 InitExpr->setType(C.getConstantArrayType(FloatTy, 241 llvm::APInt(32, 4), 242 clang::ArrayType::Normal, 243 /* EltTypeQuals = */0)); 244 245 Res = new (C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc); 246 break; 247 } 248 case RSExportPrimitiveType::DataTypeUnknown: 249 case RSExportPrimitiveType::DataTypeFloat16: 250 case RSExportPrimitiveType::DataTypeFloat32: 251 case RSExportPrimitiveType::DataTypeFloat64: 252 case RSExportPrimitiveType::DataTypeSigned8: 253 case RSExportPrimitiveType::DataTypeSigned16: 254 case RSExportPrimitiveType::DataTypeSigned32: 255 case RSExportPrimitiveType::DataTypeSigned64: 256 case RSExportPrimitiveType::DataTypeUnsigned8: 257 case RSExportPrimitiveType::DataTypeUnsigned16: 258 case RSExportPrimitiveType::DataTypeUnsigned32: 259 case RSExportPrimitiveType::DataTypeUnsigned64: 260 case RSExportPrimitiveType::DataTypeBoolean: 261 case RSExportPrimitiveType::DataTypeUnsigned565: 262 case RSExportPrimitiveType::DataTypeUnsigned5551: 263 case RSExportPrimitiveType::DataTypeUnsigned4444: 264 case RSExportPrimitiveType::DataTypeMax: { 265 assert(false && "Not RS object type!"); 266 } 267 // No default case will enable compiler detecting the missing cases 268 } 269 270 return Res; 271} 272 273void RSObjectRefCounting::VisitDeclStmt(clang::DeclStmt *DS) { 274 for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end(); 275 I != E; 276 I++) { 277 clang::Decl *D = *I; 278 if (D->getKind() == clang::Decl::Var) { 279 clang::VarDecl *VD = static_cast<clang::VarDecl*>(D); 280 if (InitializeRSObject(VD)) 281 getCurrentScope()->addRSObject(VD); 282 } 283 } 284 return; 285} 286 287void RSObjectRefCounting::VisitCompoundStmt(clang::CompoundStmt *CS) { 288 if (!CS->body_empty()) { 289 // Push a new scope 290 Scope *S = new Scope(CS); 291 mScopeStack.push(S); 292 293 VisitChildren(CS); 294 295 // Destroy the scope 296 // TODO: Update reference count of the RS object refenced by the 297 // getCurrentScope(). 298 assert((getCurrentScope() == S) && "Corrupted scope stack!"); 299 mScopeStack.pop(); 300 delete S; 301 } 302 return; 303} 304 305void RSObjectRefCounting::VisitBinAssign(clang::BinaryOperator *AS) { 306 // TODO: Update reference count 307 return; 308} 309 310void RSBackend::HandleTranslationUnitPre(clang::ASTContext& C) { 311 RSObjectRefCounting RSObjectRefCounter; 312 clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl(); 313 314 for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(), 315 E = TUDecl->decls_end(); I != E; I++) { 316 if ((I->getKind() >= clang::Decl::firstFunction) && 317 (I->getKind() <= clang::Decl::lastFunction)) { 318 clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I); 319 if (FD->hasBody() && !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) 320 RSObjectRefCounter.Visit( FD->getBody()); 321 } 322 } 323 324 return; 325} 326 327/////////////////////////////////////////////////////////////////////////////// 328void RSBackend::HandleTranslationUnitPost(llvm::Module *M) { 329 mContext->processExport(); 330 331 // Dump export variable info 332 if (mContext->hasExportVar()) { 333 if (mExportVarMetadata == NULL) 334 mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN); 335 336 llvm::SmallVector<llvm::Value*, 2> ExportVarInfo; 337 338 for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(), 339 E = mContext->export_vars_end(); 340 I != E; 341 I++) { 342 const RSExportVar *EV = *I; 343 const RSExportType *ET = EV->getType(); 344 345 // Variable name 346 ExportVarInfo.push_back( 347 llvm::MDString::get(mLLVMContext, EV->getName().c_str())); 348 349 // Type name 350 switch (ET->getClass()) { 351 case RSExportType::ExportClassPrimitive: { 352 ExportVarInfo.push_back( 353 llvm::MDString::get( 354 mLLVMContext, llvm::utostr_32( 355 static_cast<const RSExportPrimitiveType*>(ET)->getType()))); 356 break; 357 } 358 case RSExportType::ExportClassPointer: { 359 ExportVarInfo.push_back( 360 llvm::MDString::get( 361 mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET) 362 ->getPointeeType()->getName()).c_str())); 363 break; 364 } 365 case RSExportType::ExportClassMatrix: { 366 ExportVarInfo.push_back( 367 llvm::MDString::get( 368 mLLVMContext, llvm::utostr_32( 369 RSExportPrimitiveType::DataTypeRSMatrix2x2 + 370 static_cast<const RSExportMatrixType*>(ET)->getDim() - 2))); 371 break; 372 } 373 case RSExportType::ExportClassVector: 374 case RSExportType::ExportClassConstantArray: 375 case RSExportType::ExportClassRecord: { 376 ExportVarInfo.push_back( 377 llvm::MDString::get(mLLVMContext, 378 EV->getType()->getName().c_str())); 379 break; 380 } 381 } 382 383 mExportVarMetadata->addOperand( 384 llvm::MDNode::get(mLLVMContext, 385 ExportVarInfo.data(), 386 ExportVarInfo.size()) ); 387 388 ExportVarInfo.clear(); 389 } 390 } 391 392 // Dump export function info 393 if (mContext->hasExportFunc()) { 394 if (mExportFuncMetadata == NULL) 395 mExportFuncMetadata = 396 M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN); 397 398 llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo; 399 400 for (RSContext::const_export_func_iterator 401 I = mContext->export_funcs_begin(), 402 E = mContext->export_funcs_end(); 403 I != E; 404 I++) { 405 const RSExportFunc *EF = *I; 406 407 // Function name 408 if (!EF->hasParam()) { 409 ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext, 410 EF->getName().c_str())); 411 } else { 412 llvm::Function *F = M->getFunction(EF->getName()); 413 llvm::Function *HelperFunction; 414 const std::string HelperFunctionName(".helper_" + EF->getName()); 415 416 assert(F && "Function marked as exported disappeared in Bitcode"); 417 418 // Create helper function 419 { 420 llvm::StructType *HelperFunctionParameterTy = NULL; 421 422 if (!F->getArgumentList().empty()) { 423 std::vector<const llvm::Type*> HelperFunctionParameterTys; 424 for (llvm::Function::arg_iterator AI = F->arg_begin(), 425 AE = F->arg_end(); AI != AE; AI++) 426 HelperFunctionParameterTys.push_back(AI->getType()); 427 428 HelperFunctionParameterTy = 429 llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys); 430 } 431 432 if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) { 433 fprintf(stderr, "Failed to export function %s: parameter type " 434 "mismatch during creation of helper function.\n", 435 EF->getName().c_str()); 436 437 const RSExportRecordType *Expected = EF->getParamPacketType(); 438 if (Expected) { 439 fprintf(stderr, "Expected:\n"); 440 Expected->getLLVMType()->dump(); 441 } 442 if (HelperFunctionParameterTy) { 443 fprintf(stderr, "Got:\n"); 444 HelperFunctionParameterTy->dump(); 445 } 446 } 447 448 std::vector<const llvm::Type*> Params; 449 if (HelperFunctionParameterTy) { 450 llvm::PointerType *HelperFunctionParameterTyP = 451 llvm::PointerType::getUnqual(HelperFunctionParameterTy); 452 Params.push_back(HelperFunctionParameterTyP); 453 } 454 455 llvm::FunctionType * HelperFunctionType = 456 llvm::FunctionType::get(F->getReturnType(), 457 Params, 458 /* IsVarArgs = */false); 459 460 HelperFunction = 461 llvm::Function::Create(HelperFunctionType, 462 llvm::GlobalValue::ExternalLinkage, 463 HelperFunctionName, 464 M); 465 466 HelperFunction->addFnAttr(llvm::Attribute::NoInline); 467 HelperFunction->setCallingConv(F->getCallingConv()); 468 469 // Create helper function body 470 { 471 llvm::Argument *HelperFunctionParameter = 472 &(*HelperFunction->arg_begin()); 473 llvm::BasicBlock *BB = 474 llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction); 475 llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB); 476 llvm::SmallVector<llvm::Value*, 6> Params; 477 llvm::Value *Idx[2]; 478 479 Idx[0] = 480 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0); 481 482 // getelementptr and load instruction for all elements in 483 // parameter .p 484 for (size_t i = 0; i < EF->getNumParameters(); i++) { 485 // getelementptr 486 Idx[1] = 487 llvm::ConstantInt::get( 488 llvm::Type::getInt32Ty(mLLVMContext), i); 489 llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter, 490 Idx, 491 Idx + 2); 492 493 // load 494 llvm::Value *V = IB->CreateLoad(Ptr); 495 Params.push_back(V); 496 } 497 498 // Call and pass the all elements as paramter to F 499 llvm::CallInst *CI = IB->CreateCall(F, 500 Params.data(), 501 Params.data() + Params.size()); 502 503 CI->setCallingConv(F->getCallingConv()); 504 505 if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext)) 506 IB->CreateRetVoid(); 507 else 508 IB->CreateRet(CI); 509 510 delete IB; 511 } 512 } 513 514 ExportFuncInfo.push_back( 515 llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str())); 516 } 517 518 mExportFuncMetadata->addOperand( 519 llvm::MDNode::get(mLLVMContext, 520 ExportFuncInfo.data(), 521 ExportFuncInfo.size())); 522 523 ExportFuncInfo.clear(); 524 } 525 } 526 527 // Dump export type info 528 if (mContext->hasExportType()) { 529 llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo; 530 531 for (RSContext::const_export_type_iterator 532 I = mContext->export_types_begin(), 533 E = mContext->export_types_end(); 534 I != E; 535 I++) { 536 // First, dump type name list to export 537 const RSExportType *ET = I->getValue(); 538 539 ExportTypeInfo.clear(); 540 // Type name 541 ExportTypeInfo.push_back( 542 llvm::MDString::get(mLLVMContext, ET->getName().c_str())); 543 544 if (ET->getClass() == RSExportType::ExportClassRecord) { 545 const RSExportRecordType *ERT = 546 static_cast<const RSExportRecordType*>(ET); 547 548 if (mExportTypeMetadata == NULL) 549 mExportTypeMetadata = 550 M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN); 551 552 mExportTypeMetadata->addOperand( 553 llvm::MDNode::get(mLLVMContext, 554 ExportTypeInfo.data(), 555 ExportTypeInfo.size())); 556 557 // Now, export struct field information to %[struct name] 558 std::string StructInfoMetadataName("%"); 559 StructInfoMetadataName.append(ET->getName()); 560 llvm::NamedMDNode *StructInfoMetadata = 561 M->getOrInsertNamedMetadata(StructInfoMetadataName); 562 llvm::SmallVector<llvm::Value*, 3> FieldInfo; 563 564 assert(StructInfoMetadata->getNumOperands() == 0 && 565 "Metadata with same name was created before"); 566 for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(), 567 FE = ERT->fields_end(); 568 FI != FE; 569 FI++) { 570 const RSExportRecordType::Field *F = *FI; 571 572 // 1. field name 573 FieldInfo.push_back(llvm::MDString::get(mLLVMContext, 574 F->getName().c_str())); 575 576 // 2. field type name 577 FieldInfo.push_back( 578 llvm::MDString::get(mLLVMContext, 579 F->getType()->getName().c_str())); 580 581 // 3. field kind 582 switch (F->getType()->getClass()) { 583 case RSExportType::ExportClassPrimitive: 584 case RSExportType::ExportClassVector: { 585 const RSExportPrimitiveType *EPT = 586 static_cast<const RSExportPrimitiveType*>(F->getType()); 587 FieldInfo.push_back( 588 llvm::MDString::get(mLLVMContext, 589 llvm::itostr(EPT->getKind()))); 590 break; 591 } 592 593 default: { 594 FieldInfo.push_back( 595 llvm::MDString::get(mLLVMContext, 596 llvm::itostr( 597 RSExportPrimitiveType::DataKindUser))); 598 break; 599 } 600 } 601 602 StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext, 603 FieldInfo.data(), 604 FieldInfo.size())); 605 606 FieldInfo.clear(); 607 } 608 } // ET->getClass() == RSExportType::ExportClassRecord 609 } 610 } 611 612 return; 613} 614 615RSBackend::~RSBackend() { 616 return; 617} 618