slang_rs_foreach_lowering.cpp revision 88f21e16250d2e52a75607b7f0c396e1c2a34201
1/*
2 * Copyright 2015, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_foreach_lowering.h"
18
19#include "clang/AST/ASTContext.h"
20#include "llvm/Support/raw_ostream.h"
21#include "slang_rs_context.h"
22#include "slang_rs_export_foreach.h"
23
24namespace slang {
25
namespace {

// Names of the user-facing kernel launch APIs whose calls this pass rewrites.
constexpr char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsForEach";
constexpr char KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS[] = "rsForEachWithOptions";

// Mangled symbol of the function the launch calls are rewritten to call. It
// demangles to:
//   rsForEachInternal(int, rs_script_call*, int, int, rs_allocation*)
constexpr char INTERNAL_LAUNCH_FUNCTION_NAME[] =
    "_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation";

}  // anonymous namespace
34
35RSForEachLowering::RSForEachLowering(RSContext* ctxt)
36    : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {}
37
38// Check if the passed-in expr references a kernel function in the following
39// pattern in the AST.
40//
41// ImplicitCastExpr 'void *' <BitCast>
42//  `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
43//    `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
44const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator(
45    clang::Expr* expr) {
46  clang::ImplicitCastExpr* ToVoidPtr =
47      clang::dyn_cast<clang::ImplicitCastExpr>(expr);
48  if (ToVoidPtr == nullptr) {
49    return nullptr;
50  }
51
52  clang::ImplicitCastExpr* Decay =
53      clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr());
54
55  if (Decay == nullptr) {
56    return nullptr;
57  }
58
59  clang::DeclRefExpr* DRE =
60      clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr());
61
62  if (DRE == nullptr) {
63    return nullptr;
64  }
65
66  const clang::FunctionDecl* FD =
67      clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl());
68
69  if (FD == nullptr) {
70    return nullptr;
71  }
72
73  return FD;
74}
75
76// Checks if the call expression is a legal rsForEach call by looking for the
77// following pattern in the AST. On success, returns the first argument that is
78// a FunctionDecl of a kernel function.
79//
80// CallExpr 'void'
81// |
82// |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay>
83// | `-DeclRefExpr  'void (void *, ...)'  'rsForEach' 'void (void *, ...)'
84// |
85// |-ImplicitCastExpr 'void *' <BitCast>
86// | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
87// |   `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
88// |
89// |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
90// | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation'
91// |
92// `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
93//   `-DeclRefExpr  'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation'
94const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall(
95    clang::CallExpr* CE, int* slot, bool* hasOptions) {
96  const clang::Decl* D = CE->getCalleeDecl();
97  const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D);
98
99  if (FD == nullptr) {
100    return nullptr;
101  }
102
103  const clang::StringRef& funcName = FD->getName();
104
105  if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) {
106    *hasOptions = false;
107  } else if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS)) {
108    *hasOptions = true;
109  } else {
110    return nullptr;
111  }
112
113  clang::Expr* arg0 = CE->getArg(0);
114  const clang::FunctionDecl* kernel = matchFunctionDesignator(arg0);
115
116  if (kernel == nullptr) {
117    mCtxt->ReportError(arg0->getExprLoc(),
118                       "Invalid kernel launch call. "
119                       "Expects a function designator for the first argument.");
120    return nullptr;
121  }
122
123  // Verifies that kernel is indeed a "kernel" function.
124  *slot = mCtxt->getForEachSlotNumber(kernel);
125  if (*slot == -1) {
126    mCtxt->ReportError(CE->getExprLoc(), "%0 applied to non kernel function %1")
127            << funcName << kernel->getName();
128    return nullptr;
129  }
130
131  return kernel;
132}
133
134// Create an AST node for the declaration of rsForEachInternal
135clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() {
136  clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl();
137  clang::SourceLocation Loc;
138
139  llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME);
140  clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR);
141  clang::DeclarationName N(&II);
142
143  clang::FunctionProtoType::ExtProtoInfo EPI;
144
145  const clang::QualType& AllocTy = mCtxt->getAllocationType();
146  clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
147
148  clang::QualType T = mASTCtxt.getFunctionType(
149      mASTCtxt.VoidTy,       // Return type
150                             // Argument types:
151      { mASTCtxt.IntTy,      // int slot
152        mASTCtxt.VoidPtrTy,  // rs_script_call_t* launch_options
153        mASTCtxt.IntTy,      // int numOutput
154        mASTCtxt.IntTy,      // int numInputs
155        AllocPtrTy           // rs_allocation* allocs
156      },
157      EPI);
158
159  clang::FunctionDecl* FD = clang::FunctionDecl::Create(
160      mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern);
161
162  return FD;
163}
164
165// Create an expression like the following that references the rsForEachInternal to
166// replace the callee in the original call expression that references rsForEach.
167//
168// ImplicitCastExpr 'void (*)(int, rs_allocation, rs_allocation)' <FunctionToPointerDecay>
169// `-DeclRefExpr 'void' Function '_Z17rsForEachInternali13rs_allocationS_' 'void (int, rs_allocation, rs_allocation)'
170clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() {
171  clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl();
172
173  clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create(
174      mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew,
175      false, clang::SourceLocation(), mASTCtxt.VoidTy, clang::VK_RValue);
176
177  const clang::QualType FDNewType = FDNew->getType();
178
179  clang::Expr* calleeNew = clang::ImplicitCastExpr::Create(
180      mASTCtxt, mASTCtxt.getPointerType(FDNewType),
181      clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue);
182
183  return calleeNew;
184}
185
186// This visit method checks (via pattern matching) if the call expression is to
187// rsForEach, and the arguments satisfy the restrictions on the
188// rsForEach API. If so, replace the call with a rsForEachInternal call
189// with the first argument replaced by the slot number of the kernel function
190// referenced in the original first argument.
191//
192// See comments to the helper methods defined above for details.
193void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) {
194  int slot;
195  bool hasOptions;
196  const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE, &slot, &hasOptions);
197  if (kernel == nullptr) {
198    return;
199  }
200
201  slangAssert(slot >= 0);
202
203  const unsigned numArgsOrig = CE->getNumArgs();
204
205  clang::QualType resultType = kernel->getReturnType().getCanonicalType();
206  const unsigned numOutputsExpected = resultType->isVoidType() ? 0 : 1;
207
208  const unsigned numInputsExpected = RSExportForEach::getNumInputs(mCtxt->getTargetAPI(), kernel);
209
210  // Verifies that rsForEach takes the right number of input and output allocations.
211  // TODO: Check input/output allocation types match kernel function expectation.
212  const unsigned numAllocations = numArgsOrig - (hasOptions ? 2 : 1);
213  if (numInputsExpected + numOutputsExpected != numAllocations) {
214    mCtxt->ReportError(
215      CE->getExprLoc(),
216      "Number of input and output allocations unexpected for kernel function %0")
217    << kernel->getName();
218    return;
219  }
220
221  clang::Expr* calleeNew = CreateCalleeExprForInternalForEach();
222  CE->setCallee(calleeNew);
223
224  const clang::CanQualType IntTy = mASTCtxt.IntTy;
225  const unsigned IntTySize = mASTCtxt.getTypeSize(IntTy);
226  const llvm::APInt APIntSlot(IntTySize, slot);
227  const clang::Expr* arg0 = CE->getArg(0);
228  const clang::SourceLocation Loc(arg0->getLocStart());
229  clang::Expr* IntSlotNum =
230      clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, IntTy, Loc);
231  CE->setArg(0, IntSlotNum);
232
233  /*
234    The last few arguments to rsForEach or rsForEachWithOptions are allocations.
235    Creates a new compound literal of an array initialized with those values, and
236    passes it to rsForEachInternal as the last (the 5th) argument.
237
238    For example, rsForEach(foo, ain1, ain2, aout) would be translated into
239    rsForEachInternal(
240        1,                                   // Slot number for kernel
241        NULL,                                // Launch options
242        2,                                   // Number of input allocations
243        1,                                   // Number of output allocations
244        (rs_allocation[]){ain1, ain2, aout)  // Input and output allocations
245    );
246
247    The AST for the rs_allocation array looks like following:
248
249    ImplicitCastExpr 0x99575670 'struct rs_allocation *' <ArrayToPointerDecay>
250    `-CompoundLiteralExpr 0x99575648 'struct rs_allocation [3]' lvalue
251      `-InitListExpr 0x99575590 'struct rs_allocation [3]'
252      |-ImplicitCastExpr 0x99574b38 'rs_allocation':'struct rs_allocation' <LValueToRValue>
253      | `-DeclRefExpr 0x99574a08 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c408 'ain1' 'rs_allocation':'struct rs_allocation'
254      |-ImplicitCastExpr 0x99574b50 'rs_allocation':'struct rs_allocation' <LValueToRValue>
255      | `-DeclRefExpr 0x99574a30 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'ain2' 'rs_allocation':'struct rs_allocation'
256      `-ImplicitCastExpr 0x99574b68 'rs_allocation':'struct rs_allocation' <LValueToRValue>
257        `-DeclRefExpr 0x99574a58 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'aout' 'rs_allocation':'struct rs_allocation'
258  */
259
260  const clang::QualType& AllocTy = mCtxt->getAllocationType();
261  const llvm::APInt APIntNumAllocs(IntTySize, numAllocations);
262  clang::QualType AllocArrayTy = mASTCtxt.getConstantArrayType(
263      AllocTy,
264      APIntNumAllocs,
265      clang::ArrayType::ArraySizeModifier::Normal,
266      0  // index type qualifiers
267  );
268
269  const int allocArgIndexEnd = numArgsOrig - 1;
270  int allocArgIndexStart = allocArgIndexEnd;
271
272  clang::Expr** args = CE->getArgs();
273
274  clang::SourceLocation lparenloc;
275  clang::SourceLocation rparenloc;
276
277  if (numAllocations > 0) {
278    allocArgIndexStart = hasOptions ? 2 : 1;
279    lparenloc = args[allocArgIndexStart]->getExprLoc();
280    rparenloc = args[allocArgIndexEnd]->getExprLoc();
281  }
282
283  clang::InitListExpr* init = new (mASTCtxt) clang::InitListExpr(
284      mASTCtxt,
285      lparenloc,
286      llvm::ArrayRef<clang::Expr*>(args + allocArgIndexStart, numAllocations),
287      rparenloc);
288  init->setType(AllocArrayTy);
289
290  clang::TypeSourceInfo* ti = mASTCtxt.getTrivialTypeSourceInfo(AllocArrayTy);
291  clang::CompoundLiteralExpr* CLE = new (mASTCtxt) clang::CompoundLiteralExpr(
292      lparenloc,
293      ti,
294      AllocArrayTy,
295      clang::VK_LValue,  // A compound literal is an l-value in C.
296      init,
297      false  // Not file scope
298  );
299
300  const clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
301
302  clang::ImplicitCastExpr* Decay = clang::ImplicitCastExpr::Create(
303      mASTCtxt,
304      AllocPtrTy,
305      clang::CK_ArrayToPointerDecay,
306      CLE,
307      nullptr,  // C++ cast path
308      clang::VK_RValue
309  );
310
311  CE->setNumArgs(mASTCtxt, 5);
312
313  CE->setArg(4, Decay);
314
315  // Sets the new arguments for NULL launch option (if the user does not set one),
316  // the number of outputs, and the number of inputs.
317
318  if (!hasOptions) {
319    const llvm::APInt APIntZero(IntTySize, 0);
320    clang::Expr* IntNull =
321        clang::IntegerLiteral::Create(mASTCtxt, APIntZero, IntTy, Loc);
322    CE->setArg(1, IntNull);
323  }
324
325  const llvm::APInt APIntNumOutput(IntTySize, numOutputsExpected);
326  clang::Expr* IntNumOutput =
327      clang::IntegerLiteral::Create(mASTCtxt, APIntNumOutput, IntTy, Loc);
328  CE->setArg(2, IntNumOutput);
329
330  const llvm::APInt APIntNumInputs(IntTySize, numInputsExpected);
331  clang::Expr* IntNumInputs =
332      clang::IntegerLiteral::Create(mASTCtxt, APIntNumInputs, IntTy, Loc);
333  CE->setArg(3, IntNumInputs);
334}
335
336void RSForEachLowering::VisitStmt(clang::Stmt* S) {
337  for (clang::Stmt* Child : S->children()) {
338    if (Child) {
339      Visit(Child);
340    }
341  }
342}
343
344}  // namespace slang
345