slang_rs_foreach_lowering.cpp revision 2615f383dfc1542a05f19aee23b03a09bd018f4e
1/*
2 * Copyright 2015, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_foreach_lowering.h"
18
19#include "clang/AST/ASTContext.h"
20#include "llvm/Support/raw_ostream.h"
21#include "slang_rs_context.h"
22#include "slang_rs_export_foreach.h"
23
24namespace slang {
25
namespace {

// Names of the user-facing kernel launch APIs that this pass recognizes and
// rewrites.
constexpr char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsForEach";
constexpr char KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS[] = "rsForEachWithOptions";

// Pre-mangled symbol of the internal launch function that recognized calls are
// lowered into. The mangling encodes
//   rsForEachInternal(int, rs_script_call*, int, int, rs_allocation*)
constexpr char INTERNAL_LAUNCH_FUNCTION_NAME[] =
    "_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation";

}  // anonymous namespace
34
35RSForEachLowering::RSForEachLowering(RSContext* ctxt)
36    : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {}
37
38// Check if the passed-in expr references a kernel function in the following
39// pattern in the AST.
40//
41// ImplicitCastExpr 'void *' <BitCast>
42//  `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
43//    `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
44const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator(
45    clang::Expr* expr) {
46  clang::ImplicitCastExpr* ToVoidPtr =
47      clang::dyn_cast<clang::ImplicitCastExpr>(expr);
48  if (ToVoidPtr == nullptr) {
49    return nullptr;
50  }
51
52  clang::ImplicitCastExpr* Decay =
53      clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr());
54
55  if (Decay == nullptr) {
56    return nullptr;
57  }
58
59  clang::DeclRefExpr* DRE =
60      clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr());
61
62  if (DRE == nullptr) {
63    return nullptr;
64  }
65
66  const clang::FunctionDecl* FD =
67      clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl());
68
69  if (FD == nullptr) {
70    return nullptr;
71  }
72
73  return FD;
74}
75
76// Checks if the call expression is a legal rsForEach call by looking for the
77// following pattern in the AST. On success, returns the first argument that is
78// a FunctionDecl of a kernel function.
79//
80// CallExpr 'void'
81// |
82// |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay>
83// | `-DeclRefExpr  'void (void *, ...)'  'rsForEach' 'void (void *, ...)'
84// |
85// |-ImplicitCastExpr 'void *' <BitCast>
86// | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
87// |   `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
88// |
89// |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
90// | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation'
91// |
92// `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
93//   `-DeclRefExpr  'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation'
94const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall(
95    clang::CallExpr* CE, int* slot, bool* hasOptions) {
96  const clang::Decl* D = CE->getCalleeDecl();
97  const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D);
98
99  if (FD == nullptr) {
100    return nullptr;
101  }
102
103  const clang::StringRef& funcName = FD->getName();
104
105  if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) {
106    *hasOptions = false;
107  } else if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS)) {
108    *hasOptions = true;
109  } else {
110    return nullptr;
111  }
112
113  clang::Expr* arg0 = CE->getArg(0);
114  const clang::FunctionDecl* kernel = matchFunctionDesignator(arg0);
115
116  if (kernel == nullptr) {
117    mCtxt->ReportError(arg0->getExprLoc(),
118                       "Invalid kernel launch call. "
119                       "Expects a function designator for the first argument.");
120    return nullptr;
121  }
122
123  // Verifies that kernel is indeed a "kernel" function.
124  *slot = mCtxt->getForEachSlotNumber(kernel);
125  if (*slot == -1) {
126    mCtxt->ReportError(CE->getExprLoc(), "%0 applied to non kernel function %1")
127            << funcName << kernel->getName();
128    return nullptr;
129  }
130
131  return kernel;
132}
133
134// Create an AST node for the declaration of rsForEachInternal
135clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() {
136  clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl();
137  clang::SourceLocation Loc;
138
139  llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME);
140  clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR);
141  clang::DeclarationName N(&II);
142
143  clang::FunctionProtoType::ExtProtoInfo EPI;
144
145  const clang::QualType& AllocTy = mCtxt->getAllocationType();
146  clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
147
148  clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
149  const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
150
151  clang::QualType T = mASTCtxt.getFunctionType(
152      mASTCtxt.VoidTy,    // Return type
153                          // Argument types:
154      { mASTCtxt.IntTy,   // int slot
155        ScriptCallPtrTy,  // rs_script_call_t* launch_options
156        mASTCtxt.IntTy,   // int numOutput
157        mASTCtxt.IntTy,   // int numInputs
158        AllocPtrTy        // rs_allocation* allocs
159      },
160      EPI);
161
162  clang::FunctionDecl* FD = clang::FunctionDecl::Create(
163      mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern);
164
165  return FD;
166}
167
168// Create an expression like the following that references the rsForEachInternal to
169// replace the callee in the original call expression that references rsForEach.
170//
171// ImplicitCastExpr 'void (*)(int, rs_script_call_t*, int, int, rs_allocation*)' <FunctionToPointerDecay>
172// `-DeclRefExpr 'void' Function '_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation' 'void (int, rs_script_call_t*, int, int, rs_allocation*)'
173clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() {
174  clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl();
175
176  const clang::QualType FDNewType = FDNew->getType();
177
178  clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create(
179      mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew,
180      false, clang::SourceLocation(), FDNewType, clang::VK_RValue);
181
182  clang::Expr* calleeNew = clang::ImplicitCastExpr::Create(
183      mASTCtxt, mASTCtxt.getPointerType(FDNewType),
184      clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue);
185
186  return calleeNew;
187}
188
189// This visit method checks (via pattern matching) if the call expression is to
190// rsForEach, and the arguments satisfy the restrictions on the
191// rsForEach API. If so, replace the call with a rsForEachInternal call
192// with the first argument replaced by the slot number of the kernel function
193// referenced in the original first argument.
194//
195// See comments to the helper methods defined above for details.
196void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) {
197  int slot;
198  bool hasOptions;
199  const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE, &slot, &hasOptions);
200  if (kernel == nullptr) {
201    return;
202  }
203
204  slangAssert(slot >= 0);
205
206  const unsigned numArgsOrig = CE->getNumArgs();
207
208  clang::QualType resultType = kernel->getReturnType().getCanonicalType();
209  const unsigned numOutputsExpected = resultType->isVoidType() ? 0 : 1;
210
211  const unsigned numInputsExpected = RSExportForEach::getNumInputs(mCtxt->getTargetAPI(), kernel);
212
213  // Verifies that rsForEach takes the right number of input and output allocations.
214  // TODO: Check input/output allocation types match kernel function expectation.
215  const unsigned numAllocations = numArgsOrig - (hasOptions ? 2 : 1);
216  if (numInputsExpected + numOutputsExpected != numAllocations) {
217    mCtxt->ReportError(
218      CE->getExprLoc(),
219      "Number of input and output allocations unexpected for kernel function %0")
220    << kernel->getName();
221    return;
222  }
223
224  clang::Expr* calleeNew = CreateCalleeExprForInternalForEach();
225  CE->setCallee(calleeNew);
226
227  const clang::CanQualType IntTy = mASTCtxt.IntTy;
228  const unsigned IntTySize = mASTCtxt.getTypeSize(IntTy);
229  const llvm::APInt APIntSlot(IntTySize, slot);
230  const clang::Expr* arg0 = CE->getArg(0);
231  const clang::SourceLocation Loc(arg0->getLocStart());
232  clang::Expr* IntSlotNum =
233      clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, IntTy, Loc);
234  CE->setArg(0, IntSlotNum);
235
236  /*
237    The last few arguments to rsForEach or rsForEachWithOptions are allocations.
238    Creates a new compound literal of an array initialized with those values, and
239    passes it to rsForEachInternal as the last (the 5th) argument.
240
241    For example, rsForEach(foo, ain1, ain2, aout) would be translated into
242    rsForEachInternal(
243        1,                                   // Slot number for kernel
244        NULL,                                // Launch options
245        2,                                   // Number of input allocations
246        1,                                   // Number of output allocations
247        (rs_allocation[]){ain1, ain2, aout)  // Input and output allocations
248    );
249
250    The AST for the rs_allocation array looks like following:
251
252    ImplicitCastExpr 0x99575670 'struct rs_allocation *' <ArrayToPointerDecay>
253    `-CompoundLiteralExpr 0x99575648 'struct rs_allocation [3]' lvalue
254      `-InitListExpr 0x99575590 'struct rs_allocation [3]'
255      |-ImplicitCastExpr 0x99574b38 'rs_allocation':'struct rs_allocation' <LValueToRValue>
256      | `-DeclRefExpr 0x99574a08 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c408 'ain1' 'rs_allocation':'struct rs_allocation'
257      |-ImplicitCastExpr 0x99574b50 'rs_allocation':'struct rs_allocation' <LValueToRValue>
258      | `-DeclRefExpr 0x99574a30 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'ain2' 'rs_allocation':'struct rs_allocation'
259      `-ImplicitCastExpr 0x99574b68 'rs_allocation':'struct rs_allocation' <LValueToRValue>
260        `-DeclRefExpr 0x99574a58 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'aout' 'rs_allocation':'struct rs_allocation'
261  */
262
263  const clang::QualType& AllocTy = mCtxt->getAllocationType();
264  const llvm::APInt APIntNumAllocs(IntTySize, numAllocations);
265  clang::QualType AllocArrayTy = mASTCtxt.getConstantArrayType(
266      AllocTy,
267      APIntNumAllocs,
268      clang::ArrayType::ArraySizeModifier::Normal,
269      0  // index type qualifiers
270  );
271
272  const int allocArgIndexEnd = numArgsOrig - 1;
273  int allocArgIndexStart = allocArgIndexEnd;
274
275  clang::Expr** args = CE->getArgs();
276
277  clang::SourceLocation lparenloc;
278  clang::SourceLocation rparenloc;
279
280  if (numAllocations > 0) {
281    allocArgIndexStart = hasOptions ? 2 : 1;
282    lparenloc = args[allocArgIndexStart]->getExprLoc();
283    rparenloc = args[allocArgIndexEnd]->getExprLoc();
284  }
285
286  clang::InitListExpr* init = new (mASTCtxt) clang::InitListExpr(
287      mASTCtxt,
288      lparenloc,
289      llvm::ArrayRef<clang::Expr*>(args + allocArgIndexStart, numAllocations),
290      rparenloc);
291  init->setType(AllocArrayTy);
292
293  clang::TypeSourceInfo* ti = mASTCtxt.getTrivialTypeSourceInfo(AllocArrayTy);
294  clang::CompoundLiteralExpr* CLE = new (mASTCtxt) clang::CompoundLiteralExpr(
295      lparenloc,
296      ti,
297      AllocArrayTy,
298      clang::VK_LValue,  // A compound literal is an l-value in C.
299      init,
300      false  // Not file scope
301  );
302
303  const clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
304
305  clang::ImplicitCastExpr* Decay = clang::ImplicitCastExpr::Create(
306      mASTCtxt,
307      AllocPtrTy,
308      clang::CK_ArrayToPointerDecay,
309      CLE,
310      nullptr,  // C++ cast path
311      clang::VK_RValue
312  );
313
314  CE->setNumArgs(mASTCtxt, 5);
315
316  CE->setArg(4, Decay);
317
318  // Sets the new arguments for NULL launch option (if the user does not set one),
319  // the number of outputs, and the number of inputs.
320
321  if (!hasOptions) {
322    const llvm::APInt APIntZero(IntTySize, 0);
323    clang::Expr* IntNull =
324        clang::IntegerLiteral::Create(mASTCtxt, APIntZero, IntTy, Loc);
325    clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
326    const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
327    clang::CStyleCastExpr* Cast =
328        clang::CStyleCastExpr::Create(mASTCtxt,
329                                      ScriptCallPtrTy,
330                                      clang::VK_RValue,
331                                      clang::CK_NullToPointer,
332                                      IntNull,
333                                      nullptr,
334                                      mASTCtxt.getTrivialTypeSourceInfo(ScriptCallPtrTy),
335                                      clang::SourceLocation(),
336                                      clang::SourceLocation());
337    CE->setArg(1, Cast);
338  }
339
340  const llvm::APInt APIntNumOutput(IntTySize, numOutputsExpected);
341  clang::Expr* IntNumOutput =
342      clang::IntegerLiteral::Create(mASTCtxt, APIntNumOutput, IntTy, Loc);
343  CE->setArg(2, IntNumOutput);
344
345  const llvm::APInt APIntNumInputs(IntTySize, numInputsExpected);
346  clang::Expr* IntNumInputs =
347      clang::IntegerLiteral::Create(mASTCtxt, APIntNumInputs, IntTy, Loc);
348  CE->setArg(3, IntNumInputs);
349}
350
351void RSForEachLowering::VisitStmt(clang::Stmt* S) {
352  for (clang::Stmt* Child : S->children()) {
353    if (Child) {
354      Visit(Child);
355    }
356  }
357}
358
359}  // namespace slang
360