slang_rs_export_foreach.cpp revision f075ffc10278e1c127bcf041fce7ce89d428f94c
1/*
2 * Copyright 2011-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Attr.h"
23#include "clang/AST/Decl.h"
24#include "clang/AST/TypeLoc.h"
25
26#include "llvm/IR/DerivedTypes.h"
27
28#include "bcinfo/MetadataExtractor.h"
29
30#include "slang_assert.h"
31#include "slang_rs_context.h"
32#include "slang_rs_export_type.h"
33#include "slang_rs_special_func.h"
34#include "slang_rs_special_kernel_param.h"
35#include "slang_version.h"
36
37namespace {
38
39const size_t RS_KERNEL_INPUT_LIMIT = 8; // see frameworks/base/libs/rs/cpu_ref/rsCpuCoreRuntime.h
40
41bool isRootRSFunc(const clang::FunctionDecl *FD) {
42  if (!FD) {
43    return false;
44  }
45  return FD->getName().equals("root");
46}
47
48} // end anonymous namespace
49
50namespace slang {
51
52// This function takes care of additional validation and construction of
53// parameters related to forEach_* reflection.
54bool RSExportForEach::validateAndConstructParams(
55    RSContext *Context, const clang::FunctionDecl *FD) {
56  slangAssert(Context && FD);
57  bool valid = true;
58
59  numParams = FD->getNumParams();
60
61  if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
62    // Before JellyBean, we allowed only one kernel per file.  It must be called "root".
63    if (!isRootRSFunc(FD)) {
64      Context->ReportError(FD->getLocation(),
65                           "Non-root compute kernel %0() is "
66                           "not supported in SDK levels %1-%2")
67          << FD->getName() << SLANG_MINIMUM_TARGET_API
68          << (SLANG_JB_TARGET_API - 1);
69      return false;
70    }
71  }
72
73  mResultType = FD->getReturnType().getCanonicalType();
74  // Compute kernel functions are defined differently when the
75  // "__attribute__((kernel))" is set.
76  if (FD->hasAttr<clang::KernelAttr>()) {
77    valid &= validateAndConstructKernelParams(Context, FD);
78  } else {
79    valid &= validateAndConstructOldStyleParams(Context, FD);
80  }
81
82  valid &= setSignatureMetadata(Context, FD);
83  return valid;
84}
85
86bool RSExportForEach::validateAndConstructOldStyleParams(
87    RSContext *Context, const clang::FunctionDecl *FD) {
88  slangAssert(Context && FD);
89  // If numParams is 0, we already marked this as a graphics root().
90  slangAssert(numParams > 0);
91
92  bool valid = true;
93
94  // Compute kernel functions of this style are required to return a void type.
95  clang::ASTContext &C = Context->getASTContext();
96  if (mResultType != C.VoidTy) {
97    Context->ReportError(FD->getLocation(),
98                         "Compute kernel %0() is required to return a "
99                         "void type")
100        << FD->getName();
101    valid = false;
102  }
103
104  // Validate remaining parameter types
105
106  size_t IndexOfFirstSpecialParameter = numParams;
107  valid &= processSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
108
109  // Validate the non-special parameters, which should all be found before the
110  // first special parameter.
111  for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
112    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
113    clang::QualType QT = PVD->getType().getCanonicalType();
114
115    if (!QT->isPointerType()) {
116      Context->ReportError(PVD->getLocation(),
117                           "Compute kernel %0() cannot have non-pointer "
118                           "parameters besides special parameters (%1). Parameter '%2' is "
119                           "of type: '%3'")
120          << FD->getName() << listSpecialKernelParameters(Context->getTargetAPI())
121          << PVD->getName() << PVD->getType().getAsString();
122      valid = false;
123      continue;
124    }
125
126    // The only non-const pointer should be out.
127    if (!QT->getPointeeType().isConstQualified()) {
128      if (mOut == nullptr) {
129        mOut = PVD;
130      } else {
131        Context->ReportError(PVD->getLocation(),
132                             "Compute kernel %0() can only have one non-const "
133                             "pointer parameter. Parameters '%1' and '%2' are "
134                             "both non-const.")
135            << FD->getName() << mOut->getName() << PVD->getName();
136        valid = false;
137      }
138    } else {
139      if (mIns.empty() && mOut == nullptr) {
140        mIns.push_back(PVD);
141      } else if (mUsrData == nullptr) {
142        mUsrData = PVD;
143      } else {
144        Context->ReportError(
145            PVD->getLocation(),
146            "Unexpected parameter '%0' for compute kernel %1()")
147            << PVD->getName() << FD->getName();
148        valid = false;
149      }
150    }
151  }
152
153  if (mIns.empty() && !mOut) {
154    Context->ReportError(FD->getLocation(),
155                         "Compute kernel %0() must have at least one "
156                         "parameter for in or out")
157        << FD->getName();
158    valid = false;
159  }
160
161  return valid;
162}
163
164bool RSExportForEach::validateAndConstructKernelParams(
165    RSContext *Context, const clang::FunctionDecl *FD) {
166  slangAssert(Context && FD);
167  bool valid = true;
168  clang::ASTContext &C = Context->getASTContext();
169
170  if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
171    Context->ReportError(FD->getLocation(),
172                         "Compute kernel %0() targeting SDK levels "
173                         "%1-%2 may not use pass-by-value with "
174                         "__attribute__((kernel))")
175        << FD->getName() << SLANG_MINIMUM_TARGET_API
176        << (SLANG_JB_MR1_TARGET_API - 1);
177    return false;
178  }
179
180  // Denote that we are indeed a pass-by-value kernel.
181  mIsKernelStyle = true;
182  mHasReturnType = (mResultType != C.VoidTy);
183
184  if (mResultType->isPointerType()) {
185    Context->ReportError(
186        FD->getTypeSpecStartLoc(),
187        "Compute kernel %0() cannot return a pointer type: '%1'")
188        << FD->getName() << mResultType.getAsString();
189    valid = false;
190  }
191
192  // Validate remaining parameter types
193
194  size_t IndexOfFirstSpecialParameter = numParams;
195  valid &= processSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
196
197  // Validate the non-special parameters, which should all be found before the
198  // first special.
199  for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
200    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
201
202    if (Context->getTargetAPI() >= SLANG_M_TARGET_API || i == 0) {
203      if (i >= RS_KERNEL_INPUT_LIMIT) {
204        Context->ReportError(PVD->getLocation(),
205                             "Invalid parameter '%0' for compute kernel %1(). "
206                             "Kernels targeting SDK levels %2+ may not use "
207                             "more than %3 input parameters.") << PVD->getName() <<
208                             FD->getName() << SLANG_M_TARGET_API <<
209                             int(RS_KERNEL_INPUT_LIMIT);
210
211      } else {
212        mIns.push_back(PVD);
213      }
214    } else {
215      Context->ReportError(PVD->getLocation(),
216                           "Invalid parameter '%0' for compute kernel %1(). "
217                           "Kernels targeting SDK levels %2-%3 may not use "
218                           "multiple input parameters.") << PVD->getName() <<
219                           FD->getName() << SLANG_MINIMUM_TARGET_API <<
220                           (SLANG_M_TARGET_API - 1);
221      valid = false;
222    }
223    clang::QualType QT = PVD->getType().getCanonicalType();
224    if (QT->isPointerType()) {
225      Context->ReportError(PVD->getLocation(),
226                           "Compute kernel %0() cannot have "
227                           "parameter '%1' of pointer type: '%2'")
228          << FD->getName() << PVD->getName() << PVD->getType().getAsString();
229      valid = false;
230    }
231  }
232
233  // Check that we have at least one allocation to use for dimensions.
234  if (valid && mIns.empty() && !mHasReturnType && Context->getTargetAPI() < SLANG_M_TARGET_API) {
235    Context->ReportError(FD->getLocation(),
236                         "Compute kernel %0() targeting SDK levels "
237                         "%1-%2 must have at least one "
238                         "input parameter or a non-void return "
239                         "type")
240        << FD->getName() << SLANG_MINIMUM_TARGET_API
241        << (SLANG_M_TARGET_API - 1);
242    valid = false;
243  }
244
245  return valid;
246}
247
248// Process the optional special parameters:
249// - Sets *IndexOfFirstSpecialParameter to the index of the first special parameter, or
250//     FD->getNumParams() if none are found.
251// - Add bits to mSpecialParameterSignatureMetadata for the found special parameters.
252// Returns true if no errors.
253bool RSExportForEach::processSpecialParameters(
254    RSContext *Context, const clang::FunctionDecl *FD,
255    size_t *IndexOfFirstSpecialParameter) {
256  auto DiagnosticCallback = [FD] {
257    std::ostringstream DiagnosticDescription;
258    DiagnosticDescription << "compute kernel " << FD->getName().str() << "()";
259    return DiagnosticDescription.str();
260  };
261  return slang::processSpecialKernelParameters(Context,
262                                               DiagnosticCallback,
263                                               FD,
264                                               IndexOfFirstSpecialParameter,
265                                               &mSpecialParameterSignatureMetadata);
266}
267
268bool RSExportForEach::setSignatureMetadata(RSContext *Context,
269                                           const clang::FunctionDecl *FD) {
270  mSignatureMetadata = 0;
271  bool valid = true;
272
273  if (mIsKernelStyle) {
274    slangAssert(mOut == nullptr);
275    slangAssert(mUsrData == nullptr);
276  } else {
277    slangAssert(!mHasReturnType);
278  }
279
280  // Set up the bitwise metadata encoding for runtime argument passing.
281  const bool HasOut = mOut || mHasReturnType;
282  mSignatureMetadata |= (hasIns() ?       bcinfo::MD_SIG_In     : 0);
283  mSignatureMetadata |= (HasOut ?         bcinfo::MD_SIG_Out    : 0);
284  mSignatureMetadata |= (mUsrData ?       bcinfo::MD_SIG_Usr    : 0);
285  mSignatureMetadata |= (mIsKernelStyle ? bcinfo::MD_SIG_Kernel : 0);  // pass-by-value
286  mSignatureMetadata |= mSpecialParameterSignatureMetadata;
287
288  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
289    // APIs before ICS cannot skip between parameters. It is ok, however, for
290    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
291    if (mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
292                               bcinfo::MD_SIG_X | bcinfo::MD_SIG_Y) &&
293        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
294                               bcinfo::MD_SIG_X) &&
295        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr) &&
296        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out) &&
297        mSignatureMetadata != (bcinfo::MD_SIG_In)) {
298      Context->ReportError(FD->getLocation(),
299                           "Compute kernel %0() targeting SDK levels "
300                           "%1-%2 may not skip parameters")
301          << FD->getName() << SLANG_MINIMUM_TARGET_API
302          << (SLANG_ICS_TARGET_API - 1);
303      valid = false;
304    }
305  }
306  return valid;
307}
308
309RSExportForEach *RSExportForEach::Create(RSContext *Context,
310                                         const clang::FunctionDecl *FD) {
311  slangAssert(Context && FD);
312  llvm::StringRef Name = FD->getName();
313  RSExportForEach *FE;
314
315  slangAssert(!Name.empty() && "Function must have a name");
316
317  FE = new RSExportForEach(Context, Name);
318
319  if (!FE->validateAndConstructParams(Context, FD)) {
320    return nullptr;
321  }
322
323  clang::ASTContext &Ctx = Context->getASTContext();
324
325  std::string Id = CreateDummyName("helper_foreach_param", FE->getName());
326
327  // Extract the usrData parameter (if we have one)
328  if (FE->mUsrData) {
329    const clang::ParmVarDecl *PVD = FE->mUsrData;
330    clang::QualType QT = PVD->getType().getCanonicalType();
331    slangAssert(QT->isPointerType() &&
332                QT->getPointeeType().isConstQualified());
333
334    const clang::ASTContext &C = Context->getASTContext();
335    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
336        C.VoidTy) {
337      // In the case of using const void*, we can't reflect an appopriate
338      // Java type, so we fall back to just reflecting the ain/aout parameters
339      FE->mUsrData = nullptr;
340    } else {
341      clang::RecordDecl *RD =
342          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
343                                    Ctx.getTranslationUnitDecl(),
344                                    clang::SourceLocation(),
345                                    clang::SourceLocation(),
346                                    &Ctx.Idents.get(Id));
347
348      clang::FieldDecl *FD =
349          clang::FieldDecl::Create(Ctx,
350                                   RD,
351                                   clang::SourceLocation(),
352                                   clang::SourceLocation(),
353                                   PVD->getIdentifier(),
354                                   QT->getPointeeType(),
355                                   nullptr,
356                                   /* BitWidth = */ nullptr,
357                                   /* Mutable = */ false,
358                                   /* HasInit = */ clang::ICIS_NoInit);
359      RD->addDecl(FD);
360      RD->completeDefinition();
361
362      // Create an export type iff we have a valid usrData type
363      clang::QualType T = Ctx.getTagDeclType(RD);
364      slangAssert(!T.isNull());
365
366      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
367
368      slangAssert(ET && "Failed to export a kernel");
369
370      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
371                  "Parameter packet must be a record");
372
373      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
374    }
375  }
376
377  // Construct type information about inputs and outputs. Return null when
378  // there is an error exporting types.
379
380  bool TypeExportError = false;
381
382  if (FE->hasIns()) {
383    for (InIter BI = FE->mIns.begin(), EI = FE->mIns.end(); BI != EI; BI++) {
384      const clang::Type *T = (*BI)->getType().getCanonicalType().getTypePtr();
385      RSExportType *InExportType = RSExportType::Create(Context, T);
386
387      // It is not an error if we don't export an input type for legacy
388      // kernels. This can happen in the case of a void pointer.
389      if (FE->mIsKernelStyle && !InExportType) {
390        TypeExportError = true;
391      }
392
393      FE->mInTypes.push_back(InExportType);
394    }
395  }
396
397  if (FE->mIsKernelStyle && FE->mHasReturnType) {
398    const clang::Type *ReturnType = FE->mResultType.getTypePtr();
399    FE->mOutType = RSExportType::Create(Context, ReturnType);
400    TypeExportError |= !FE->mOutType;
401  } else if (FE->mOut) {
402    const clang::Type *OutType =
403        FE->mOut->getType().getCanonicalType().getTypePtr();
404    FE->mOutType = RSExportType::Create(Context, OutType);
405    // It is not an error if we don't export an output type.
406    // This can happen in the case of a void pointer.
407  }
408
409  if (TypeExportError) {
410    slangAssert(Context->getDiagnostics()->hasErrorOccurred() &&
411                "Error exporting type but no diagnostic message issued!");
412    return nullptr;
413  }
414
415  return FE;
416}
417
418RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
419  slangAssert(Context);
420  llvm::StringRef Name = "root";
421  RSExportForEach *FE = new RSExportForEach(Context, Name);
422  FE->mDummyRoot = true;
423  return FE;
424}
425
426bool RSExportForEach::isRSForEachFunc(unsigned int targetAPI,
427                                      const clang::FunctionDecl *FD) {
428  if (!FD) {
429    return false;
430  }
431
432  // Anything tagged as a kernel("") is definitely used with ForEach.
433  if (auto *Kernel = FD->getAttr<clang::KernelAttr>()) {
434    return Kernel->getKernelKind().empty();
435  }
436
437  if (RSSpecialFunc::isGraphicsRootRSFunc(targetAPI, FD)) {
438    return false;
439  }
440
441  // Check if first parameter is a pointer (which is required for ForEach).
442  unsigned int numParams = FD->getNumParams();
443
444  if (numParams > 0) {
445    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
446    clang::QualType QT = PVD->getType().getCanonicalType();
447
448    if (QT->isPointerType()) {
449      return true;
450    }
451
452    // Any non-graphics root() is automatically a ForEach candidate.
453    // At this point, however, we know that it is not going to be a valid
454    // compute root() function (due to not having a pointer parameter). We
455    // still want to return true here, so that we can issue appropriate
456    // diagnostics.
457    if (isRootRSFunc(FD)) {
458      return true;
459    }
460  }
461
462  return false;
463}
464
465unsigned RSExportForEach::getNumInputs(unsigned int targetAPI,
466                                       const clang::FunctionDecl *FD) {
467  unsigned numInputs = 0;
468  for (const clang::ParmVarDecl* param : FD->params()) {
469    if (!isSpecialKernelParameter(param->getName())) {
470      numInputs++;
471    }
472  }
473
474  return numInputs;
475}
476
477}  // namespace slang
478