slang_rs_foreach_lowering.cpp revision 88f21e16250d2e52a75607b7f0c396e1c2a34201
1/* 2 * Copyright 2015, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "slang_rs_foreach_lowering.h" 18 19#include "clang/AST/ASTContext.h" 20#include "llvm/Support/raw_ostream.h" 21#include "slang_rs_context.h" 22#include "slang_rs_export_foreach.h" 23 24namespace slang { 25 26namespace { 27 28const char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsForEach"; 29const char KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS[] = "rsForEachWithOptions"; 30const char INTERNAL_LAUNCH_FUNCTION_NAME[] = 31 "_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation"; 32 33} // anonymous namespace 34 35RSForEachLowering::RSForEachLowering(RSContext* ctxt) 36 : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {} 37 38// Check if the passed-in expr references a kernel function in the following 39// pattern in the AST. 40// 41// ImplicitCastExpr 'void *' <BitCast> 42// `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay> 43// `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)' 44const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator( 45 clang::Expr* expr) { 46 clang::ImplicitCastExpr* ToVoidPtr = 47 clang::dyn_cast<clang::ImplicitCastExpr>(expr); 48 if (ToVoidPtr == nullptr) { 49 return nullptr; 50 } 51 52 clang::ImplicitCastExpr* Decay = 53 clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr()); 54 55 if (Decay == nullptr) { 56 return nullptr; 57 } 58 59 clang::DeclRefExpr* DRE = 60 clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr()); 61 62 if (DRE == nullptr) { 63 return nullptr; 64 } 65 66 const clang::FunctionDecl* FD = 67 clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl()); 68 69 if (FD == nullptr) { 70 return nullptr; 71 } 72 73 return FD; 74} 75 76// Checks if the call expression is a legal rsForEach call by looking for the 77// following pattern in the AST. On success, returns the first argument that is 78// a FunctionDecl of a kernel function. 79// 80// CallExpr 'void' 81// | 82// |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay> 83// | `-DeclRefExpr 'void (void *, ...)' 'rsForEach' 'void (void *, ...)' 84// | 85// |-ImplicitCastExpr 'void *' <BitCast> 86// | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay> 87// | `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)' 88// | 89// |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue> 90// | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation' 91// | 92// `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue> 93// `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation' 94const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall( 95 clang::CallExpr* CE, int* slot, bool* hasOptions) { 96 const clang::Decl* D = CE->getCalleeDecl(); 97 const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D); 98 99 if (FD == nullptr) { 100 return nullptr; 101 } 102 103 const clang::StringRef& funcName = FD->getName(); 104 105 if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) { 106 *hasOptions = false; 107 } else if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS)) { 108 *hasOptions = true; 109 } else { 110 return nullptr; 111 } 112 113 clang::Expr* arg0 = CE->getArg(0); 114 const clang::FunctionDecl* kernel = matchFunctionDesignator(arg0); 115 116 if (kernel == nullptr) { 117 mCtxt->ReportError(arg0->getExprLoc(), 118 "Invalid kernel launch call. " 119 "Expects a function designator for the first argument."); 120 return nullptr; 121 } 122 123 // Verifies that kernel is indeed a "kernel" function. 124 *slot = mCtxt->getForEachSlotNumber(kernel); 125 if (*slot == -1) { 126 mCtxt->ReportError(CE->getExprLoc(), "%0 applied to non kernel function %1") 127 << funcName << kernel->getName(); 128 return nullptr; 129 } 130 131 return kernel; 132} 133 134// Create an AST node for the declaration of rsForEachInternal 135clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() { 136 clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl(); 137 clang::SourceLocation Loc; 138 139 llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME); 140 clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR); 141 clang::DeclarationName N(&II); 142 143 clang::FunctionProtoType::ExtProtoInfo EPI; 144 145 const clang::QualType& AllocTy = mCtxt->getAllocationType(); 146 clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy); 147 148 clang::QualType T = mASTCtxt.getFunctionType( 149 mASTCtxt.VoidTy, // Return type 150 // Argument types: 151 { mASTCtxt.IntTy, // int slot 152 mASTCtxt.VoidPtrTy, // rs_script_call_t* launch_options 153 mASTCtxt.IntTy, // int numOutput 154 mASTCtxt.IntTy, // int numInputs 155 AllocPtrTy // rs_allocation* allocs 156 }, 157 EPI); 158 159 clang::FunctionDecl* FD = clang::FunctionDecl::Create( 160 mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern); 161 162 return FD; 163} 164 165// Create an expression like the following that references the rsForEachInternal to 166// replace the callee in the original call expression that references rsForEach. 167// 168// ImplicitCastExpr 'void (*)(int, rs_allocation, rs_allocation)' <FunctionToPointerDecay> 169// `-DeclRefExpr 'void' Function '_Z17rsForEachInternali13rs_allocationS_' 'void (int, rs_allocation, rs_allocation)' 170clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() { 171 clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl(); 172 173 clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create( 174 mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew, 175 false, clang::SourceLocation(), mASTCtxt.VoidTy, clang::VK_RValue); 176 177 const clang::QualType FDNewType = FDNew->getType(); 178 179 clang::Expr* calleeNew = clang::ImplicitCastExpr::Create( 180 mASTCtxt, mASTCtxt.getPointerType(FDNewType), 181 clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue); 182 183 return calleeNew; 184} 185 186// This visit method checks (via pattern matching) if the call expression is to 187// rsForEach, and the arguments satisfy the restrictions on the 188// rsForEach API. If so, replace the call with a rsForEachInternal call 189// with the first argument replaced by the slot number of the kernel function 190// referenced in the original first argument. 191// 192// See comments to the helper methods defined above for details. 193void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) { 194 int slot; 195 bool hasOptions; 196 const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE, &slot, &hasOptions); 197 if (kernel == nullptr) { 198 return; 199 } 200 201 slangAssert(slot >= 0); 202 203 const unsigned numArgsOrig = CE->getNumArgs(); 204 205 clang::QualType resultType = kernel->getReturnType().getCanonicalType(); 206 const unsigned numOutputsExpected = resultType->isVoidType() ? 0 : 1; 207 208 const unsigned numInputsExpected = RSExportForEach::getNumInputs(mCtxt->getTargetAPI(), kernel); 209 210 // Verifies that rsForEach takes the right number of input and output allocations. 211 // TODO: Check input/output allocation types match kernel function expectation. 212 const unsigned numAllocations = numArgsOrig - (hasOptions ? 2 : 1); 213 if (numInputsExpected + numOutputsExpected != numAllocations) { 214 mCtxt->ReportError( 215 CE->getExprLoc(), 216 "Number of input and output allocations unexpected for kernel function %0") 217 << kernel->getName(); 218 return; 219 } 220 221 clang::Expr* calleeNew = CreateCalleeExprForInternalForEach(); 222 CE->setCallee(calleeNew); 223 224 const clang::CanQualType IntTy = mASTCtxt.IntTy; 225 const unsigned IntTySize = mASTCtxt.getTypeSize(IntTy); 226 const llvm::APInt APIntSlot(IntTySize, slot); 227 const clang::Expr* arg0 = CE->getArg(0); 228 const clang::SourceLocation Loc(arg0->getLocStart()); 229 clang::Expr* IntSlotNum = 230 clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, IntTy, Loc); 231 CE->setArg(0, IntSlotNum); 232 233 /* 234 The last few arguments to rsForEach or rsForEachWithOptions are allocations. 235 Creates a new compound literal of an array initialized with those values, and 236 passes it to rsForEachInternal as the last (the 5th) argument. 237 238 For example, rsForEach(foo, ain1, ain2, aout) would be translated into 239 rsForEachInternal( 240 1, // Slot number for kernel 241 NULL, // Launch options 242 2, // Number of input allocations 243 1, // Number of output allocations 244 (rs_allocation[]){ain1, ain2, aout) // Input and output allocations 245 ); 246 247 The AST for the rs_allocation array looks like following: 248 249 ImplicitCastExpr 0x99575670 'struct rs_allocation *' <ArrayToPointerDecay> 250 `-CompoundLiteralExpr 0x99575648 'struct rs_allocation [3]' lvalue 251 `-InitListExpr 0x99575590 'struct rs_allocation [3]' 252 |-ImplicitCastExpr 0x99574b38 'rs_allocation':'struct rs_allocation' <LValueToRValue> 253 | `-DeclRefExpr 0x99574a08 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c408 'ain1' 'rs_allocation':'struct rs_allocation' 254 |-ImplicitCastExpr 0x99574b50 'rs_allocation':'struct rs_allocation' <LValueToRValue> 255 | `-DeclRefExpr 0x99574a30 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'ain2' 'rs_allocation':'struct rs_allocation' 256 `-ImplicitCastExpr 0x99574b68 'rs_allocation':'struct rs_allocation' <LValueToRValue> 257 `-DeclRefExpr 0x99574a58 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'aout' 'rs_allocation':'struct rs_allocation' 258 */ 259 260 const clang::QualType& AllocTy = mCtxt->getAllocationType(); 261 const llvm::APInt APIntNumAllocs(IntTySize, numAllocations); 262 clang::QualType AllocArrayTy = mASTCtxt.getConstantArrayType( 263 AllocTy, 264 APIntNumAllocs, 265 clang::ArrayType::ArraySizeModifier::Normal, 266 0 // index type qualifiers 267 ); 268 269 const int allocArgIndexEnd = numArgsOrig - 1; 270 int allocArgIndexStart = allocArgIndexEnd; 271 272 clang::Expr** args = CE->getArgs(); 273 274 clang::SourceLocation lparenloc; 275 clang::SourceLocation rparenloc; 276 277 if (numAllocations > 0) { 278 allocArgIndexStart = hasOptions ? 2 : 1; 279 lparenloc = args[allocArgIndexStart]->getExprLoc(); 280 rparenloc = args[allocArgIndexEnd]->getExprLoc(); 281 } 282 283 clang::InitListExpr* init = new (mASTCtxt) clang::InitListExpr( 284 mASTCtxt, 285 lparenloc, 286 llvm::ArrayRef<clang::Expr*>(args + allocArgIndexStart, numAllocations), 287 rparenloc); 288 init->setType(AllocArrayTy); 289 290 clang::TypeSourceInfo* ti = mASTCtxt.getTrivialTypeSourceInfo(AllocArrayTy); 291 clang::CompoundLiteralExpr* CLE = new (mASTCtxt) clang::CompoundLiteralExpr( 292 lparenloc, 293 ti, 294 AllocArrayTy, 295 clang::VK_LValue, // A compound literal is an l-value in C. 296 init, 297 false // Not file scope 298 ); 299 300 const clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy); 301 302 clang::ImplicitCastExpr* Decay = clang::ImplicitCastExpr::Create( 303 mASTCtxt, 304 AllocPtrTy, 305 clang::CK_ArrayToPointerDecay, 306 CLE, 307 nullptr, // C++ cast path 308 clang::VK_RValue 309 ); 310 311 CE->setNumArgs(mASTCtxt, 5); 312 313 CE->setArg(4, Decay); 314 315 // Sets the new arguments for NULL launch option (if the user does not set one), 316 // the number of outputs, and the number of inputs. 317 318 if (!hasOptions) { 319 const llvm::APInt APIntZero(IntTySize, 0); 320 clang::Expr* IntNull = 321 clang::IntegerLiteral::Create(mASTCtxt, APIntZero, IntTy, Loc); 322 CE->setArg(1, IntNull); 323 } 324 325 const llvm::APInt APIntNumOutput(IntTySize, numOutputsExpected); 326 clang::Expr* IntNumOutput = 327 clang::IntegerLiteral::Create(mASTCtxt, APIntNumOutput, IntTy, Loc); 328 CE->setArg(2, IntNumOutput); 329 330 const llvm::APInt APIntNumInputs(IntTySize, numInputsExpected); 331 clang::Expr* IntNumInputs = 332 clang::IntegerLiteral::Create(mASTCtxt, APIntNumInputs, IntTy, Loc); 333 CE->setArg(3, IntNumInputs); 334} 335 336void RSForEachLowering::VisitStmt(clang::Stmt* S) { 337 for (clang::Stmt* Child : S->children()) { 338 if (Child) { 339 Visit(Child); 340 } 341 } 342} 343 344} // namespace slang 345