1/*
2 * Copyright 2015, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_foreach_lowering.h"
18
19#include "clang/AST/ASTContext.h"
20#include "clang/AST/Attr.h"
21#include "llvm/Support/raw_ostream.h"
22#include "slang_rs_context.h"
23#include "slang_rs_export_foreach.h"
24
25namespace slang {
26
namespace {

// Names of the user-facing kernel launch APIs that this pass recognizes.
constexpr char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsForEach";
constexpr char KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS[] = "rsForEachWithOptions";

// Mangled symbol of the function that recognized launches are rewritten to
// call. Decoded, it is:
//   rsForEachInternal(int, rs_script_call*, int, int, rs_allocation*)
constexpr char INTERNAL_LAUNCH_FUNCTION_NAME[] =
    "_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation";

}  // anonymous namespace
35
36RSForEachLowering::RSForEachLowering(RSContext* ctxt)
37    : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {}
38
39// Check if the passed-in expr references a kernel function in the following
40// pattern in the AST.
41//
42// ImplicitCastExpr 'void *' <BitCast>
43//  `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
44//    `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
45const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator(
46    clang::Expr* expr) {
47  clang::ImplicitCastExpr* ToVoidPtr =
48      clang::dyn_cast<clang::ImplicitCastExpr>(expr);
49  if (ToVoidPtr == nullptr) {
50    return nullptr;
51  }
52
53  clang::ImplicitCastExpr* Decay =
54      clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr());
55
56  if (Decay == nullptr) {
57    return nullptr;
58  }
59
60  clang::DeclRefExpr* DRE =
61      clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr());
62
63  if (DRE == nullptr) {
64    return nullptr;
65  }
66
67  const clang::FunctionDecl* FD =
68      clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl());
69
70  if (FD == nullptr) {
71    return nullptr;
72  }
73
74  return FD;
75}
76
77// Checks if the call expression is a legal rsForEach call by looking for the
78// following pattern in the AST. On success, returns the first argument that is
79// a FunctionDecl of a kernel function.
80//
81// CallExpr 'void'
82// |
83// |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay>
84// | `-DeclRefExpr  'void (void *, ...)'  'rsForEach' 'void (void *, ...)'
85// |
86// |-ImplicitCastExpr 'void *' <BitCast>
87// | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
88// |   `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
89// |
90// |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
91// | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation'
92// |
93// `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
94//   `-DeclRefExpr  'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation'
95const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall(
96    clang::CallExpr* CE, int* slot, bool* hasOptions) {
97  const clang::Decl* D = CE->getCalleeDecl();
98  const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D);
99
100  if (FD == nullptr) {
101    return nullptr;
102  }
103
104  const clang::StringRef& funcName = FD->getName();
105
106  if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) {
107    *hasOptions = false;
108  } else if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS)) {
109    *hasOptions = true;
110  } else {
111    return nullptr;
112  }
113
114  if (mInsideKernel) {
115    mCtxt->ReportError(CE->getExprLoc(),
116        "Invalid kernel launch call made from inside another kernel.");
117    return nullptr;
118  }
119
120  clang::Expr* arg0 = CE->getArg(0);
121  const clang::FunctionDecl* kernel = matchFunctionDesignator(arg0);
122
123  if (kernel == nullptr) {
124    mCtxt->ReportError(arg0->getExprLoc(),
125                       "Invalid kernel launch call. "
126                       "Expects a function designator for the first argument.");
127    return nullptr;
128  }
129
130  // Verifies that kernel is indeed a "kernel" function.
131  *slot = mCtxt->getForEachSlotNumber(kernel);
132  if (*slot == -1) {
133    mCtxt->ReportError(CE->getExprLoc(),
134         "%0 applied to function %1 defined without \"kernel\" attribute")
135         << funcName << kernel->getName();
136    return nullptr;
137  }
138
139  return kernel;
140}
141
142// Create an AST node for the declaration of rsForEachInternal
143clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() {
144  clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl();
145  clang::SourceLocation Loc;
146
147  llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME);
148  clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR);
149  clang::DeclarationName N(&II);
150
151  clang::FunctionProtoType::ExtProtoInfo EPI;
152
153  const clang::QualType& AllocTy = mCtxt->getAllocationType();
154  clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
155
156  clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
157  const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
158
159  clang::QualType ParamTypes[] = {
160    mASTCtxt.IntTy,   // int slot
161    ScriptCallPtrTy,  // rs_script_call_t* launch_options
162    mASTCtxt.IntTy,   // int numOutput
163    mASTCtxt.IntTy,   // int numInputs
164    AllocPtrTy        // rs_allocation* allocs
165  };
166
167  clang::QualType T = mASTCtxt.getFunctionType(
168      mASTCtxt.VoidTy,  // Return type
169      ParamTypes,       // Parameter types
170      EPI);
171
172  clang::FunctionDecl* FD = clang::FunctionDecl::Create(
173      mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern);
174
175  static constexpr unsigned kNumParams = sizeof(ParamTypes) / sizeof(ParamTypes[0]);
176  clang::ParmVarDecl *ParamDecls[kNumParams];
177  for (unsigned I = 0; I != kNumParams; ++I) {
178    ParamDecls[I] = clang::ParmVarDecl::Create(mASTCtxt, FD, Loc,
179        Loc, nullptr, ParamTypes[I], nullptr, clang::SC_None, nullptr);
180    // Implicit means that this declaration was created by the compiler, and
181    // not part of the actual source code.
182    ParamDecls[I]->setImplicit();
183  }
184  FD->setParams(llvm::makeArrayRef(ParamDecls, kNumParams));
185
186  // Implicit means that this declaration was created by the compiler, and
187  // not part of the actual source code.
188  FD->setImplicit();
189
190  return FD;
191}
192
193// Create an expression like the following that references the rsForEachInternal to
194// replace the callee in the original call expression that references rsForEach.
195//
196// ImplicitCastExpr 'void (*)(int, rs_script_call_t*, int, int, rs_allocation*)' <FunctionToPointerDecay>
197// `-DeclRefExpr 'void' Function '_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation' 'void (int, rs_script_call_t*, int, int, rs_allocation*)'
198clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() {
199  clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl();
200
201  const clang::QualType FDNewType = FDNew->getType();
202
203  clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create(
204      mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew,
205      false, clang::SourceLocation(), FDNewType, clang::VK_RValue);
206
207  clang::Expr* calleeNew = clang::ImplicitCastExpr::Create(
208      mASTCtxt, mASTCtxt.getPointerType(FDNewType),
209      clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue);
210
211  return calleeNew;
212}
213
214// This visit method checks (via pattern matching) if the call expression is to
215// rsForEach, and the arguments satisfy the restrictions on the
216// rsForEach API. If so, replace the call with a rsForEachInternal call
217// with the first argument replaced by the slot number of the kernel function
218// referenced in the original first argument.
219//
220// See comments to the helper methods defined above for details.
221void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) {
222  int slot;
223  bool hasOptions;
224  const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE, &slot, &hasOptions);
225  if (kernel == nullptr) {
226    return;
227  }
228
229  slangAssert(slot >= 0);
230
231  const unsigned numArgsOrig = CE->getNumArgs();
232
233  clang::QualType resultType = kernel->getReturnType().getCanonicalType();
234  const unsigned numOutputsExpected = resultType->isVoidType() ? 0 : 1;
235
236  const unsigned numInputsExpected = RSExportForEach::getNumInputs(mCtxt->getTargetAPI(), kernel);
237
238  // Verifies that rsForEach takes the right number of input and output allocations.
239  // TODO: Check input/output allocation types match kernel function expectation.
240  const unsigned numAllocations = numArgsOrig - (hasOptions ? 2 : 1);
241  if (numInputsExpected + numOutputsExpected != numAllocations) {
242    mCtxt->ReportError(
243      CE->getExprLoc(),
244      "Number of input and output allocations unexpected for kernel function %0")
245    << kernel->getName();
246    return;
247  }
248
249  clang::Expr* calleeNew = CreateCalleeExprForInternalForEach();
250  CE->setCallee(calleeNew);
251
252  const clang::CanQualType IntTy = mASTCtxt.IntTy;
253  const unsigned IntTySize = mASTCtxt.getTypeSize(IntTy);
254  const llvm::APInt APIntSlot(IntTySize, slot);
255  const clang::Expr* arg0 = CE->getArg(0);
256  const clang::SourceLocation Loc(arg0->getLocStart());
257  clang::Expr* IntSlotNum =
258      clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, IntTy, Loc);
259  CE->setArg(0, IntSlotNum);
260
261  /*
262    The last few arguments to rsForEach or rsForEachWithOptions are allocations.
263    Creates a new compound literal of an array initialized with those values, and
264    passes it to rsForEachInternal as the last (the 5th) argument.
265
266    For example, rsForEach(foo, ain1, ain2, aout) would be translated into
267    rsForEachInternal(
268        1,                                   // Slot number for kernel
269        NULL,                                // Launch options
270        2,                                   // Number of input allocations
271        1,                                   // Number of output allocations
272        (rs_allocation[]){ain1, ain2, aout)  // Input and output allocations
273    );
274
275    The AST for the rs_allocation array looks like following:
276
277    ImplicitCastExpr 0x99575670 'struct rs_allocation *' <ArrayToPointerDecay>
278    `-CompoundLiteralExpr 0x99575648 'struct rs_allocation [3]' lvalue
279      `-InitListExpr 0x99575590 'struct rs_allocation [3]'
280      |-ImplicitCastExpr 0x99574b38 'rs_allocation':'struct rs_allocation' <LValueToRValue>
281      | `-DeclRefExpr 0x99574a08 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c408 'ain1' 'rs_allocation':'struct rs_allocation'
282      |-ImplicitCastExpr 0x99574b50 'rs_allocation':'struct rs_allocation' <LValueToRValue>
283      | `-DeclRefExpr 0x99574a30 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'ain2' 'rs_allocation':'struct rs_allocation'
284      `-ImplicitCastExpr 0x99574b68 'rs_allocation':'struct rs_allocation' <LValueToRValue>
285        `-DeclRefExpr 0x99574a58 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'aout' 'rs_allocation':'struct rs_allocation'
286  */
287
288  const clang::QualType& AllocTy = mCtxt->getAllocationType();
289  const llvm::APInt APIntNumAllocs(IntTySize, numAllocations);
290  clang::QualType AllocArrayTy = mASTCtxt.getConstantArrayType(
291      AllocTy,
292      APIntNumAllocs,
293      clang::ArrayType::ArraySizeModifier::Normal,
294      0  // index type qualifiers
295  );
296
297  const int allocArgIndexEnd = numArgsOrig - 1;
298  int allocArgIndexStart = allocArgIndexEnd;
299
300  clang::Expr** args = CE->getArgs();
301
302  clang::SourceLocation lparenloc;
303  clang::SourceLocation rparenloc;
304
305  if (numAllocations > 0) {
306    allocArgIndexStart = hasOptions ? 2 : 1;
307    lparenloc = args[allocArgIndexStart]->getExprLoc();
308    rparenloc = args[allocArgIndexEnd]->getExprLoc();
309  }
310
311  clang::InitListExpr* init = new (mASTCtxt) clang::InitListExpr(
312      mASTCtxt,
313      lparenloc,
314      llvm::ArrayRef<clang::Expr*>(args + allocArgIndexStart, numAllocations),
315      rparenloc);
316  init->setType(AllocArrayTy);
317
318  clang::TypeSourceInfo* ti = mASTCtxt.getTrivialTypeSourceInfo(AllocArrayTy);
319  clang::CompoundLiteralExpr* CLE = new (mASTCtxt) clang::CompoundLiteralExpr(
320      lparenloc,
321      ti,
322      AllocArrayTy,
323      clang::VK_LValue,  // A compound literal is an l-value in C.
324      init,
325      false  // Not file scope
326  );
327
328  const clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
329
330  clang::ImplicitCastExpr* Decay = clang::ImplicitCastExpr::Create(
331      mASTCtxt,
332      AllocPtrTy,
333      clang::CK_ArrayToPointerDecay,
334      CLE,
335      nullptr,  // C++ cast path
336      clang::VK_RValue
337  );
338
339  CE->setNumArgs(mASTCtxt, 5);
340
341  CE->setArg(4, Decay);
342
343  // Sets the new arguments for NULL launch option (if the user does not set one),
344  // the number of outputs, and the number of inputs.
345
346  if (!hasOptions) {
347    const llvm::APInt APIntZero(IntTySize, 0);
348    clang::Expr* IntNull =
349        clang::IntegerLiteral::Create(mASTCtxt, APIntZero, IntTy, Loc);
350    clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
351    const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
352    clang::CStyleCastExpr* Cast =
353        clang::CStyleCastExpr::Create(mASTCtxt,
354                                      ScriptCallPtrTy,
355                                      clang::VK_RValue,
356                                      clang::CK_NullToPointer,
357                                      IntNull,
358                                      nullptr,
359                                      mASTCtxt.getTrivialTypeSourceInfo(ScriptCallPtrTy),
360                                      clang::SourceLocation(),
361                                      clang::SourceLocation());
362    CE->setArg(1, Cast);
363  }
364
365  const llvm::APInt APIntNumOutput(IntTySize, numOutputsExpected);
366  clang::Expr* IntNumOutput =
367      clang::IntegerLiteral::Create(mASTCtxt, APIntNumOutput, IntTy, Loc);
368  CE->setArg(2, IntNumOutput);
369
370  const llvm::APInt APIntNumInputs(IntTySize, numInputsExpected);
371  clang::Expr* IntNumInputs =
372      clang::IntegerLiteral::Create(mASTCtxt, APIntNumInputs, IntTy, Loc);
373  CE->setArg(3, IntNumInputs);
374}
375
376void RSForEachLowering::VisitStmt(clang::Stmt* S) {
377  for (clang::Stmt* Child : S->children()) {
378    if (Child) {
379      Visit(Child);
380    }
381  }
382}
383
384void RSForEachLowering::handleForEachCalls(clang::FunctionDecl* FD,
385                                           unsigned int targetAPI) {
386  slangAssert(FD && FD->hasBody());
387
388  mInsideKernel = FD->hasAttr<clang::KernelAttr>();
389  VisitStmt(FD->getBody());
390}
391
392}  // namespace slang
393