slang_rs_foreach_lowering.cpp revision 2615f383dfc1542a05f19aee23b03a09bd018f4e
1/* 2 * Copyright 2015, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "slang_rs_foreach_lowering.h" 18 19#include "clang/AST/ASTContext.h" 20#include "llvm/Support/raw_ostream.h" 21#include "slang_rs_context.h" 22#include "slang_rs_export_foreach.h" 23 24namespace slang { 25 26namespace { 27 28const char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsForEach"; 29const char KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS[] = "rsForEachWithOptions"; 30const char INTERNAL_LAUNCH_FUNCTION_NAME[] = 31 "_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation"; 32 33} // anonymous namespace 34 35RSForEachLowering::RSForEachLowering(RSContext* ctxt) 36 : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {} 37 38// Check if the passed-in expr references a kernel function in the following 39// pattern in the AST. 40// 41// ImplicitCastExpr 'void *' <BitCast> 42// `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay> 43// `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)' 44const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator( 45 clang::Expr* expr) { 46 clang::ImplicitCastExpr* ToVoidPtr = 47 clang::dyn_cast<clang::ImplicitCastExpr>(expr); 48 if (ToVoidPtr == nullptr) { 49 return nullptr; 50 } 51 52 clang::ImplicitCastExpr* Decay = 53 clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr()); 54 55 if (Decay == nullptr) { 56 return nullptr; 57 } 58 59 clang::DeclRefExpr* DRE = 60 clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr()); 61 62 if (DRE == nullptr) { 63 return nullptr; 64 } 65 66 const clang::FunctionDecl* FD = 67 clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl()); 68 69 if (FD == nullptr) { 70 return nullptr; 71 } 72 73 return FD; 74} 75 76// Checks if the call expression is a legal rsForEach call by looking for the 77// following pattern in the AST. On success, returns the first argument that is 78// a FunctionDecl of a kernel function. 79// 80// CallExpr 'void' 81// | 82// |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay> 83// | `-DeclRefExpr 'void (void *, ...)' 'rsForEach' 'void (void *, ...)' 84// | 85// |-ImplicitCastExpr 'void *' <BitCast> 86// | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay> 87// | `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)' 88// | 89// |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue> 90// | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation' 91// | 92// `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue> 93// `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation' 94const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall( 95 clang::CallExpr* CE, int* slot, bool* hasOptions) { 96 const clang::Decl* D = CE->getCalleeDecl(); 97 const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D); 98 99 if (FD == nullptr) { 100 return nullptr; 101 } 102 103 const clang::StringRef& funcName = FD->getName(); 104 105 if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) { 106 *hasOptions = false; 107 } else if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS)) { 108 *hasOptions = true; 109 } else { 110 return nullptr; 111 } 112 113 clang::Expr* arg0 = CE->getArg(0); 114 const clang::FunctionDecl* kernel = matchFunctionDesignator(arg0); 115 116 if (kernel == nullptr) { 117 mCtxt->ReportError(arg0->getExprLoc(), 118 "Invalid kernel launch call. " 119 "Expects a function designator for the first argument."); 120 return nullptr; 121 } 122 123 // Verifies that kernel is indeed a "kernel" function. 124 *slot = mCtxt->getForEachSlotNumber(kernel); 125 if (*slot == -1) { 126 mCtxt->ReportError(CE->getExprLoc(), "%0 applied to non kernel function %1") 127 << funcName << kernel->getName(); 128 return nullptr; 129 } 130 131 return kernel; 132} 133 134// Create an AST node for the declaration of rsForEachInternal 135clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() { 136 clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl(); 137 clang::SourceLocation Loc; 138 139 llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME); 140 clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR); 141 clang::DeclarationName N(&II); 142 143 clang::FunctionProtoType::ExtProtoInfo EPI; 144 145 const clang::QualType& AllocTy = mCtxt->getAllocationType(); 146 clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy); 147 148 clang::QualType ScriptCallTy = mCtxt->getScriptCallType(); 149 const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy); 150 151 clang::QualType T = mASTCtxt.getFunctionType( 152 mASTCtxt.VoidTy, // Return type 153 // Argument types: 154 { mASTCtxt.IntTy, // int slot 155 ScriptCallPtrTy, // rs_script_call_t* launch_options 156 mASTCtxt.IntTy, // int numOutput 157 mASTCtxt.IntTy, // int numInputs 158 AllocPtrTy // rs_allocation* allocs 159 }, 160 EPI); 161 162 clang::FunctionDecl* FD = clang::FunctionDecl::Create( 163 mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern); 164 165 return FD; 166} 167 168// Create an expression like the following that references the rsForEachInternal to 169// replace the callee in the original call expression that references rsForEach. 170// 171// ImplicitCastExpr 'void (*)(int, rs_script_call_t*, int, int, rs_allocation*)' <FunctionToPointerDecay> 172// `-DeclRefExpr 'void' Function '_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation' 'void (int, rs_script_call_t*, int, int, rs_allocation*)' 173clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() { 174 clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl(); 175 176 const clang::QualType FDNewType = FDNew->getType(); 177 178 clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create( 179 mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew, 180 false, clang::SourceLocation(), FDNewType, clang::VK_RValue); 181 182 clang::Expr* calleeNew = clang::ImplicitCastExpr::Create( 183 mASTCtxt, mASTCtxt.getPointerType(FDNewType), 184 clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue); 185 186 return calleeNew; 187} 188 189// This visit method checks (via pattern matching) if the call expression is to 190// rsForEach, and the arguments satisfy the restrictions on the 191// rsForEach API. If so, replace the call with a rsForEachInternal call 192// with the first argument replaced by the slot number of the kernel function 193// referenced in the original first argument. 194// 195// See comments to the helper methods defined above for details. 196void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) { 197 int slot; 198 bool hasOptions; 199 const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE, &slot, &hasOptions); 200 if (kernel == nullptr) { 201 return; 202 } 203 204 slangAssert(slot >= 0); 205 206 const unsigned numArgsOrig = CE->getNumArgs(); 207 208 clang::QualType resultType = kernel->getReturnType().getCanonicalType(); 209 const unsigned numOutputsExpected = resultType->isVoidType() ? 0 : 1; 210 211 const unsigned numInputsExpected = RSExportForEach::getNumInputs(mCtxt->getTargetAPI(), kernel); 212 213 // Verifies that rsForEach takes the right number of input and output allocations. 214 // TODO: Check input/output allocation types match kernel function expectation. 215 const unsigned numAllocations = numArgsOrig - (hasOptions ? 2 : 1); 216 if (numInputsExpected + numOutputsExpected != numAllocations) { 217 mCtxt->ReportError( 218 CE->getExprLoc(), 219 "Number of input and output allocations unexpected for kernel function %0") 220 << kernel->getName(); 221 return; 222 } 223 224 clang::Expr* calleeNew = CreateCalleeExprForInternalForEach(); 225 CE->setCallee(calleeNew); 226 227 const clang::CanQualType IntTy = mASTCtxt.IntTy; 228 const unsigned IntTySize = mASTCtxt.getTypeSize(IntTy); 229 const llvm::APInt APIntSlot(IntTySize, slot); 230 const clang::Expr* arg0 = CE->getArg(0); 231 const clang::SourceLocation Loc(arg0->getLocStart()); 232 clang::Expr* IntSlotNum = 233 clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, IntTy, Loc); 234 CE->setArg(0, IntSlotNum); 235 236 /* 237 The last few arguments to rsForEach or rsForEachWithOptions are allocations. 238 Creates a new compound literal of an array initialized with those values, and 239 passes it to rsForEachInternal as the last (the 5th) argument. 240 241 For example, rsForEach(foo, ain1, ain2, aout) would be translated into 242 rsForEachInternal( 243 1, // Slot number for kernel 244 NULL, // Launch options 245 2, // Number of input allocations 246 1, // Number of output allocations 247 (rs_allocation[]){ain1, ain2, aout) // Input and output allocations 248 ); 249 250 The AST for the rs_allocation array looks like following: 251 252 ImplicitCastExpr 0x99575670 'struct rs_allocation *' <ArrayToPointerDecay> 253 `-CompoundLiteralExpr 0x99575648 'struct rs_allocation [3]' lvalue 254 `-InitListExpr 0x99575590 'struct rs_allocation [3]' 255 |-ImplicitCastExpr 0x99574b38 'rs_allocation':'struct rs_allocation' <LValueToRValue> 256 | `-DeclRefExpr 0x99574a08 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c408 'ain1' 'rs_allocation':'struct rs_allocation' 257 |-ImplicitCastExpr 0x99574b50 'rs_allocation':'struct rs_allocation' <LValueToRValue> 258 | `-DeclRefExpr 0x99574a30 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'ain2' 'rs_allocation':'struct rs_allocation' 259 `-ImplicitCastExpr 0x99574b68 'rs_allocation':'struct rs_allocation' <LValueToRValue> 260 `-DeclRefExpr 0x99574a58 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'aout' 'rs_allocation':'struct rs_allocation' 261 */ 262 263 const clang::QualType& AllocTy = mCtxt->getAllocationType(); 264 const llvm::APInt APIntNumAllocs(IntTySize, numAllocations); 265 clang::QualType AllocArrayTy = mASTCtxt.getConstantArrayType( 266 AllocTy, 267 APIntNumAllocs, 268 clang::ArrayType::ArraySizeModifier::Normal, 269 0 // index type qualifiers 270 ); 271 272 const int allocArgIndexEnd = numArgsOrig - 1; 273 int allocArgIndexStart = allocArgIndexEnd; 274 275 clang::Expr** args = CE->getArgs(); 276 277 clang::SourceLocation lparenloc; 278 clang::SourceLocation rparenloc; 279 280 if (numAllocations > 0) { 281 allocArgIndexStart = hasOptions ? 2 : 1; 282 lparenloc = args[allocArgIndexStart]->getExprLoc(); 283 rparenloc = args[allocArgIndexEnd]->getExprLoc(); 284 } 285 286 clang::InitListExpr* init = new (mASTCtxt) clang::InitListExpr( 287 mASTCtxt, 288 lparenloc, 289 llvm::ArrayRef<clang::Expr*>(args + allocArgIndexStart, numAllocations), 290 rparenloc); 291 init->setType(AllocArrayTy); 292 293 clang::TypeSourceInfo* ti = mASTCtxt.getTrivialTypeSourceInfo(AllocArrayTy); 294 clang::CompoundLiteralExpr* CLE = new (mASTCtxt) clang::CompoundLiteralExpr( 295 lparenloc, 296 ti, 297 AllocArrayTy, 298 clang::VK_LValue, // A compound literal is an l-value in C. 299 init, 300 false // Not file scope 301 ); 302 303 const clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy); 304 305 clang::ImplicitCastExpr* Decay = clang::ImplicitCastExpr::Create( 306 mASTCtxt, 307 AllocPtrTy, 308 clang::CK_ArrayToPointerDecay, 309 CLE, 310 nullptr, // C++ cast path 311 clang::VK_RValue 312 ); 313 314 CE->setNumArgs(mASTCtxt, 5); 315 316 CE->setArg(4, Decay); 317 318 // Sets the new arguments for NULL launch option (if the user does not set one), 319 // the number of outputs, and the number of inputs. 320 321 if (!hasOptions) { 322 const llvm::APInt APIntZero(IntTySize, 0); 323 clang::Expr* IntNull = 324 clang::IntegerLiteral::Create(mASTCtxt, APIntZero, IntTy, Loc); 325 clang::QualType ScriptCallTy = mCtxt->getScriptCallType(); 326 const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy); 327 clang::CStyleCastExpr* Cast = 328 clang::CStyleCastExpr::Create(mASTCtxt, 329 ScriptCallPtrTy, 330 clang::VK_RValue, 331 clang::CK_NullToPointer, 332 IntNull, 333 nullptr, 334 mASTCtxt.getTrivialTypeSourceInfo(ScriptCallPtrTy), 335 clang::SourceLocation(), 336 clang::SourceLocation()); 337 CE->setArg(1, Cast); 338 } 339 340 const llvm::APInt APIntNumOutput(IntTySize, numOutputsExpected); 341 clang::Expr* IntNumOutput = 342 clang::IntegerLiteral::Create(mASTCtxt, APIntNumOutput, IntTy, Loc); 343 CE->setArg(2, IntNumOutput); 344 345 const llvm::APInt APIntNumInputs(IntTySize, numInputsExpected); 346 clang::Expr* IntNumInputs = 347 clang::IntegerLiteral::Create(mASTCtxt, APIntNumInputs, IntTy, Loc); 348 CE->setArg(3, IntNumInputs); 349} 350 351void RSForEachLowering::VisitStmt(clang::Stmt* S) { 352 for (clang::Stmt* Child : S->children()) { 353 if (Child) { 354 Visit(Child); 355 } 356 } 357} 358 359} // namespace slang 360