slang_rs_foreach_lowering.cpp revision fb40ee2a90f37967bf4a40a18dec7f60e5c580d8
1/* 2 * Copyright 2015, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "slang_rs_foreach_lowering.h" 18 19#include "clang/AST/ASTContext.h" 20#include "llvm/Support/raw_ostream.h" 21#include "slang_rs_context.h" 22#include "slang_rs_export_foreach.h" 23 24namespace slang { 25 26namespace { 27 28const char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsParallelFor"; 29const char INTERNAL_LAUNCH_FUNCTION_NAME[] = 30 "_Z17rsForEachInternali13rs_allocationS_"; 31 32} // anonymous namespace 33 34RSForEachLowering::RSForEachLowering(RSContext* ctxt) 35 : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {} 36 37// Check if the passed-in expr references a kernel function in the following 38// pattern in the AST. 39// 40// ImplicitCastExpr 'void *' <BitCast> 41// `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay> 42// `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)' 43const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator( 44 clang::Expr* expr) { 45 clang::ImplicitCastExpr* ToVoidPtr = 46 clang::dyn_cast<clang::ImplicitCastExpr>(expr); 47 if (ToVoidPtr == nullptr) { 48 return nullptr; 49 } 50 51 clang::ImplicitCastExpr* Decay = 52 clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr()); 53 54 if (Decay == nullptr) { 55 return nullptr; 56 } 57 58 clang::DeclRefExpr* DRE = 59 clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr()); 60 61 if (DRE == nullptr) { 62 return nullptr; 63 } 64 65 const clang::FunctionDecl* FD = 66 clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl()); 67 68 if (FD == nullptr) { 69 return nullptr; 70 } 71 72 // TODO: Verify the launch has the expected number of input allocations 73 74 return FD; 75} 76 77// Checks if the call expression is a legal rsParallelFor call by looking for the 78// following pattern in the AST. On success, returns the first argument that is 79// a FunctionDecl of a kernel function. 80// 81// CallExpr 'void' 82// | 83// |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay> 84// | `-DeclRefExpr 'void (void *, ...)' 'rsParallelFor' 'void (void *, ...)' 85// | 86// |-ImplicitCastExpr 'void *' <BitCast> 87// | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay> 88// | `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)' 89// | 90// |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue> 91// | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation' 92// | 93// `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue> 94// `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation' 95const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall( 96 clang::CallExpr* CE) { 97 const clang::Decl* D = CE->getCalleeDecl(); 98 const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D); 99 100 if (FD == nullptr) { 101 return nullptr; 102 } 103 104 const clang::StringRef& funcName = FD->getName(); 105 106 if (!funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) { 107 return nullptr; 108 } 109 110 const clang::FunctionDecl* kernel = matchFunctionDesignator(CE->getArg(0)); 111 112 if (kernel == nullptr || 113 CE->getNumArgs() < 3) { // TODO: Make argument check more accurate 114 mCtxt->ReportError(CE->getExprLoc(), "Invalid kernel launch call."); 115 } 116 117 return kernel; 118} 119 120// Create an AST node for the declaration of rsForEachInternal 121clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() { 122 const clang::QualType& AllocTy = mCtxt->getAllocationType(); 123 clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl(); 124 clang::SourceLocation Loc; 125 126 llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME); 127 clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR); 128 clang::DeclarationName N(&II); 129 130 clang::FunctionProtoType::ExtProtoInfo EPI; 131 132 clang::QualType T = mASTCtxt.getFunctionType( 133 mASTCtxt.VoidTy, // Return type 134 {mASTCtxt.IntTy, AllocTy, AllocTy}, // Argument types 135 EPI); 136 137 clang::FunctionDecl* FD = clang::FunctionDecl::Create( 138 mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern); 139 return FD; 140} 141 142// Create an expression like the following that references the rsForEachInternal to 143// replace the callee in the original call expression that references rsParallelFor. 144// 145// ImplicitCastExpr 'void (*)(int, rs_allocation, rs_allocation)' <FunctionToPointerDecay> 146// `-DeclRefExpr 'void' Function '_Z17rsForEachInternali13rs_allocationS_' 'void (int, rs_allocation, rs_allocation)' 147clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() { 148 clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl(); 149 150 clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create( 151 mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew, 152 false, clang::SourceLocation(), mASTCtxt.VoidTy, clang::VK_RValue); 153 154 const clang::QualType FDNewType = FDNew->getType(); 155 156 clang::Expr* calleeNew = clang::ImplicitCastExpr::Create( 157 mASTCtxt, mASTCtxt.getPointerType(FDNewType), 158 clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue); 159 160 return calleeNew; 161} 162 163// This visit method checks (via pattern matching) if the call expression is to 164// rsParallelFor, and the arguments satisfy the restrictions on the 165// rsParallelFor API. If so, replace the call with a rsForEachInternal call 166// with the first argument replaced by the slot number of the kernel function 167// referenced in the original first argument. 168// 169// See comments to the helper methods defined above for details. 170void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) { 171 const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE); 172 if (kernel == nullptr) { 173 return; 174 } 175 176 clang::Expr* calleeNew = CreateCalleeExprForInternalForEach(); 177 CE->setCallee(calleeNew); 178 179 const int slot = mCtxt->getForEachSlotNumber(kernel); 180 const llvm::APInt APIntSlot(mASTCtxt.getTypeSize(mASTCtxt.IntTy), slot); 181 const clang::Expr* arg0 = CE->getArg(0); 182 const clang::SourceLocation Loc(arg0->getLocStart()); 183 clang::Expr* IntSlotNum = 184 clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, mASTCtxt.IntTy, Loc); 185 CE->setArg(0, IntSlotNum); 186} 187 188void RSForEachLowering::VisitStmt(clang::Stmt* S) { 189 for (clang::Stmt* Child : S->children()) { 190 if (Child) { 191 Visit(Child); 192 } 193 } 194} 195 196} // namespace slang 197