slang_rs_export_foreach.cpp revision bd0a7ddceac6c135ea975cefbac73877a1f9dae7
1/*
2 * Copyright 2011-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Attr.h"
23#include "clang/AST/Decl.h"
24#include "clang/AST/TypeLoc.h"
25
26#include "llvm/IR/DerivedTypes.h"
27
28#include "bcinfo/MetadataExtractor.h"
29
30#include "slang_assert.h"
31#include "slang_rs_context.h"
32#include "slang_rs_export_type.h"
33#include "slang_rs_special_func.h"
34#include "slang_version.h"
35
36namespace {
37
38const size_t RS_KERNEL_INPUT_LIMIT = 8; // see frameworks/base/libs/rs/cpu_ref/rsCpuCoreRuntime.h
39
40enum SpecialParameterKind {
41  SPK_INT,  // 'int' or 'unsigned int'
42  SPK_CTXT, // rs_kernel_context
43};
44
45struct SpecialParameter {
46  const char *name;
47  bcinfo::MetadataSignatureBitval bitval;
48  SpecialParameterKind kind;
49  SlangTargetAPI minAPI;
50};
51
52// Table entries are in the order parameters must occur in a kernel parameter list.
53const SpecialParameter specialParameterTable[] = {
54  { "ctxt", bcinfo::MD_SIG_Ctxt, SPK_CTXT, SLANG_M_TARGET_API },
55  { "x", bcinfo::MD_SIG_X, SPK_INT, SLANG_MINIMUM_TARGET_API },
56  { "y", bcinfo::MD_SIG_Y, SPK_INT, SLANG_MINIMUM_TARGET_API },
57  { "z", bcinfo::MD_SIG_Z, SPK_INT, SLANG_M_TARGET_API },
58  { nullptr, bcinfo::MD_SIG_None, SPK_INT, SLANG_MINIMUM_TARGET_API }, // marks end of table
59};
60
61// If the specified name matches the name of an entry in
62// specialParameterTable, return the corresponding table index;
63// otherwise return -1.
64int lookupSpecialParameter(const llvm::StringRef name) {
65  for (int i = 0; specialParameterTable[i].name != nullptr; ++i)
66    if (name.equals(specialParameterTable[i].name))
67      return i;
68  return -1;
69}
70
71// Return a comma-separated list of names in specialParameterTable
72// that are available at the specified API level.
73std::string listSpecialParameters(unsigned int api) {
74  std::string ret;
75  bool first = true;
76  for (int i = 0; specialParameterTable[i].name != nullptr; ++i) {
77    if (specialParameterTable[i].minAPI > api)
78      continue;
79    if (first)
80      first = false;
81    else
82      ret += ", ";
83    ret += "'";
84    ret += specialParameterTable[i].name;
85    ret += "'";
86  }
87  return ret;
88}
89
90bool isRootRSFunc(const clang::FunctionDecl *FD) {
91  if (!FD) {
92    return false;
93  }
94  return FD->getName().equals("root");
95}
96
97} // end anonymous namespace
98
99namespace slang {
100
101// This function takes care of additional validation and construction of
102// parameters related to forEach_* reflection.
103bool RSExportForEach::validateAndConstructParams(
104    RSContext *Context, const clang::FunctionDecl *FD) {
105  slangAssert(Context && FD);
106  bool valid = true;
107
108  numParams = FD->getNumParams();
109
110  if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
111    // Before JellyBean, we allowed only one kernel per file.  It must be called "root".
112    if (!isRootRSFunc(FD)) {
113      Context->ReportError(FD->getLocation(),
114                           "Non-root compute kernel %0() is "
115                           "not supported in SDK levels %1-%2")
116          << FD->getName() << SLANG_MINIMUM_TARGET_API
117          << (SLANG_JB_TARGET_API - 1);
118      return false;
119    }
120  }
121
122  mResultType = FD->getReturnType().getCanonicalType();
123  // Compute kernel functions are defined differently when the
124  // "__attribute__((kernel))" is set.
125  if (FD->hasAttr<clang::KernelAttr>()) {
126    valid |= validateAndConstructKernelParams(Context, FD);
127  } else {
128    valid |= validateAndConstructOldStyleParams(Context, FD);
129  }
130
131  valid |= setSignatureMetadata(Context, FD);
132  return valid;
133}
134
135bool RSExportForEach::validateAndConstructOldStyleParams(
136    RSContext *Context, const clang::FunctionDecl *FD) {
137  slangAssert(Context && FD);
138  // If numParams is 0, we already marked this as a graphics root().
139  slangAssert(numParams > 0);
140
141  bool valid = true;
142
143  // Compute kernel functions of this style are required to return a void type.
144  clang::ASTContext &C = Context->getASTContext();
145  if (mResultType != C.VoidTy) {
146    Context->ReportError(FD->getLocation(),
147                         "Compute kernel %0() is required to return a "
148                         "void type")
149        << FD->getName();
150    valid = false;
151  }
152
153  // Validate remaining parameter types
154
155  size_t IndexOfFirstSpecialParameter = numParams;
156  valid |= validateSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
157
158  // Validate the non-special parameters, which should all be found before the
159  // first special parameter.
160  for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
161    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
162    clang::QualType QT = PVD->getType().getCanonicalType();
163
164    if (!QT->isPointerType()) {
165      Context->ReportError(PVD->getLocation(),
166                           "Compute kernel %0() cannot have non-pointer "
167                           "parameters besides (%1). Parameter '%2' is "
168                           "of type: '%3'")
169          << FD->getName() << listSpecialParameters(Context->getTargetAPI())
170          << PVD->getName() << PVD->getType().getAsString();
171      valid = false;
172      continue;
173    }
174
175    // The only non-const pointer should be out.
176    if (!QT->getPointeeType().isConstQualified()) {
177      if (mOut == nullptr) {
178        mOut = PVD;
179      } else {
180        Context->ReportError(PVD->getLocation(),
181                             "Compute kernel %0() can only have one non-const "
182                             "pointer parameter. Parameters '%1' and '%2' are "
183                             "both non-const.")
184            << FD->getName() << mOut->getName() << PVD->getName();
185        valid = false;
186      }
187    } else {
188      if (mIns.empty() && mOut == nullptr) {
189        mIns.push_back(PVD);
190      } else if (mUsrData == nullptr) {
191        mUsrData = PVD;
192      } else {
193        Context->ReportError(
194            PVD->getLocation(),
195            "Unexpected parameter '%0' for compute kernel %1()")
196            << PVD->getName() << FD->getName();
197        valid = false;
198      }
199    }
200  }
201
202  if (mIns.empty() && !mOut) {
203    Context->ReportError(FD->getLocation(),
204                         "Compute kernel %0() must have at least one "
205                         "parameter for in or out")
206        << FD->getName();
207    valid = false;
208  }
209
210  return valid;
211}
212
213bool RSExportForEach::validateAndConstructKernelParams(
214    RSContext *Context, const clang::FunctionDecl *FD) {
215  slangAssert(Context && FD);
216  bool valid = true;
217  clang::ASTContext &C = Context->getASTContext();
218
219  if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
220    Context->ReportError(FD->getLocation(),
221                         "Compute kernel %0() targeting SDK levels "
222                         "%1-%2 may not use pass-by-value with "
223                         "__attribute__((kernel))")
224        << FD->getName() << SLANG_MINIMUM_TARGET_API
225        << (SLANG_JB_MR1_TARGET_API - 1);
226    return false;
227  }
228
229  // Denote that we are indeed a pass-by-value kernel.
230  mIsKernelStyle = true;
231  mHasReturnType = (mResultType != C.VoidTy);
232
233  if (mResultType->isPointerType()) {
234    Context->ReportError(
235        FD->getTypeSpecStartLoc(),
236        "Compute kernel %0() cannot return a pointer type: '%1'")
237        << FD->getName() << mResultType.getAsString();
238    valid = false;
239  }
240
241  // Validate remaining parameter types
242
243  size_t IndexOfFirstSpecialParameter = numParams;
244  valid |= validateSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
245
246  // Validate the non-special parameters, which should all be found before the
247  // first special.
248  for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
249    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
250
251    if (Context->getTargetAPI() >= SLANG_M_TARGET_API || i == 0) {
252      if (i >= RS_KERNEL_INPUT_LIMIT) {
253        Context->ReportError(PVD->getLocation(),
254                             "Invalid parameter '%0' for compute kernel %1(). "
255                             "Kernels targeting SDK levels %2+ may not use "
256                             "more than %3 input parameters.") << PVD->getName() <<
257                             FD->getName() << SLANG_M_TARGET_API <<
258                             int(RS_KERNEL_INPUT_LIMIT);
259
260      } else {
261        mIns.push_back(PVD);
262      }
263    } else {
264      Context->ReportError(PVD->getLocation(),
265                           "Invalid parameter '%0' for compute kernel %1(). "
266                           "Kernels targeting SDK levels %2-%3 may not use "
267                           "multiple input parameters.") << PVD->getName() <<
268                           FD->getName() << SLANG_MINIMUM_TARGET_API <<
269                           (SLANG_M_TARGET_API - 1);
270      valid = false;
271    }
272    clang::QualType QT = PVD->getType().getCanonicalType();
273    if (QT->isPointerType()) {
274      Context->ReportError(PVD->getLocation(),
275                           "Compute kernel %0() cannot have "
276                           "parameter '%1' of pointer type: '%2'")
277          << FD->getName() << PVD->getName() << PVD->getType().getAsString();
278      valid = false;
279    }
280  }
281
282  // Check that we have at least one allocation to use for dimensions.
283  if (valid && mIns.empty() && !mHasReturnType && Context->getTargetAPI() < SLANG_M_TARGET_API) {
284    Context->ReportError(FD->getLocation(),
285                         "Compute kernel %0() targeting SDK levels "
286                         "%1-%2 must have at least one "
287                         "input parameter or a non-void return "
288                         "type")
289        << FD->getName() << SLANG_MINIMUM_TARGET_API
290        << (SLANG_M_TARGET_API - 1);
291    valid = false;
292  }
293
294  return valid;
295}
296
297// Search for the optional special parameters.  Returns true if valid.   Also
298// sets *IndexOfFirstSpecialParameter to the index of the first special parameter, or
299// FD->getNumParams() if none are found.
300bool RSExportForEach::validateSpecialParameters(
301    RSContext *Context, const clang::FunctionDecl *FD,
302    size_t *IndexOfFirstSpecialParameter) {
303  slangAssert(IndexOfFirstSpecialParameter != nullptr);
304  slangAssert(mSpecialParameterSignatureMetadata == 0);
305  clang::ASTContext &C = Context->getASTContext();
306
307  // Find all special parameters if present.
308  int LastSpecialParameterIdx = -1;     // index into specialParameterTable
309  int FirstIntSpecialParameterIdx = -1; // index into specialParameterTable
310  clang::QualType FirstIntSpecialParameterType;
311  size_t NumParams = FD->getNumParams();
312  *IndexOfFirstSpecialParameter = NumParams;
313  bool valid = true;
314  for (size_t i = 0; i < NumParams; i++) {
315    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
316    llvm::StringRef ParamName = PVD->getName();
317    int SpecialParameterIdx = lookupSpecialParameter(ParamName);
318    if (SpecialParameterIdx >= 0) {
319      const SpecialParameter &SP = specialParameterTable[SpecialParameterIdx];
320      // We won't be invoked if two parameters of the same name are present.
321      slangAssert(!(mSpecialParameterSignatureMetadata & SP.bitval));
322
323      if (Context->getTargetAPI() < SP.minAPI) {
324        Context->ReportError(PVD->getLocation(),
325                             "Compute kernel %0() targeting SDK levels "
326                             "%1-%2 may not use parameter '%3'.")
327            << FD->getName()
328            << SLANG_MINIMUM_TARGET_API
329            << (SP.minAPI - 1)
330            << SP.name;
331        valid = false;
332      }
333
334      mSpecialParameterSignatureMetadata |= SP.bitval;
335      if (SpecialParameterIdx < LastSpecialParameterIdx) {
336        Context->ReportError(PVD->getLocation(),
337                             "In compute kernel %0(), parameter '%1' must "
338                             "be defined before parameter '%2'.")
339            << FD->getName()
340            << SP.name
341            << specialParameterTable[LastSpecialParameterIdx].name;
342        valid = false;
343      }
344      LastSpecialParameterIdx = SpecialParameterIdx;
345
346      // Ensure that all SPK_INT special parameters have the same type.
347      if (SP.kind == SPK_INT) {
348        clang::QualType SpecialParameterType = PVD->getType();
349        if (FirstIntSpecialParameterIdx >= 0) {
350          if (SpecialParameterType != FirstIntSpecialParameterType) {
351            Context->ReportError(PVD->getLocation(),
352                                 "Parameters '%0' and '%1' must be of the same type. "
353                                 "'%0' is of type '%2' while '%1' is of type '%3'.")
354                << specialParameterTable[FirstIntSpecialParameterIdx].name
355                << SP.name
356                << FirstIntSpecialParameterType.getAsString()
357                << SpecialParameterType.getAsString();
358            valid = false;
359          }
360        } else {
361          FirstIntSpecialParameterIdx = SpecialParameterIdx;
362          FirstIntSpecialParameterType = SpecialParameterType;
363        }
364      }
365    } else {
366      // It's not a special parameter.
367      if (*IndexOfFirstSpecialParameter < NumParams) {
368        Context->ReportError(PVD->getLocation(),
369                             "In compute kernel %0(), parameter '%1' cannot "
370                             "appear after any of the (%2) parameters.")
371            << FD->getName() << ParamName << listSpecialParameters(Context->getTargetAPI());
372        valid = false;
373      }
374      continue;
375    }
376    // Validate the data type of the special parameter.
377    switch (specialParameterTable[SpecialParameterIdx].kind) {
378      case SPK_INT: {
379        clang::QualType QT = PVD->getType().getCanonicalType();
380        clang::QualType UT = QT.getUnqualifiedType();
381        if (UT != C.UnsignedIntTy && UT != C.IntTy) {
382          Context->ReportError(PVD->getLocation(),
383                               "Parameter '%0' must be of type 'int' or "
384                               "'unsigned int'. It is of type '%1'.")
385              << ParamName << PVD->getType().getAsString();
386          valid = false;
387        }
388        break;
389      }
390      case SPK_CTXT: {
391        static const char ExpectedTypeNameMatch[] = "const struct rs_kernel_context_t *";
392        static const char ExpectedTypeNamePrint[] = "rs_kernel_context";
393        clang::QualType QT = PVD->getType().getCanonicalType();
394        clang::QualType UT = QT.getUnqualifiedType();
395        if (UT.getAsString() != ExpectedTypeNameMatch) {
396          Context->ReportError(PVD->getLocation(),
397                               "Parameter '%0' must be of type '%1'. "
398                               "It is of type '%2'.")
399              << ParamName << ExpectedTypeNamePrint << PVD->getType().getAsString();
400          valid = false;
401        }
402        break;
403      }
404      default:
405        slangAssert(!"Unexpected special parameter type");
406    }
407    // If this is the first time we find a special parameter, save it.
408    if (*IndexOfFirstSpecialParameter >= NumParams) {
409      *IndexOfFirstSpecialParameter = i;
410    }
411  }
412  return valid;
413}
414
415bool RSExportForEach::setSignatureMetadata(RSContext *Context,
416                                           const clang::FunctionDecl *FD) {
417  mSignatureMetadata = 0;
418  bool valid = true;
419
420  if (mIsKernelStyle) {
421    slangAssert(mOut == nullptr);
422    slangAssert(mUsrData == nullptr);
423  } else {
424    slangAssert(!mHasReturnType);
425  }
426
427  // Set up the bitwise metadata encoding for runtime argument passing.
428  const bool HasOut = mOut || mHasReturnType;
429  mSignatureMetadata |= (hasIns() ?       bcinfo::MD_SIG_In     : 0);
430  mSignatureMetadata |= (HasOut ?         bcinfo::MD_SIG_Out    : 0);
431  mSignatureMetadata |= (mUsrData ?       bcinfo::MD_SIG_Usr    : 0);
432  mSignatureMetadata |= (mIsKernelStyle ? bcinfo::MD_SIG_Kernel : 0);  // pass-by-value
433  mSignatureMetadata |= mSpecialParameterSignatureMetadata;
434
435  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
436    // APIs before ICS cannot skip between parameters. It is ok, however, for
437    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
438    if (mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
439                               bcinfo::MD_SIG_X | bcinfo::MD_SIG_Y) &&
440        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
441                               bcinfo::MD_SIG_X) &&
442        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr) &&
443        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out) &&
444        mSignatureMetadata != (bcinfo::MD_SIG_In)) {
445      Context->ReportError(FD->getLocation(),
446                           "Compute kernel %0() targeting SDK levels "
447                           "%1-%2 may not skip parameters")
448          << FD->getName() << SLANG_MINIMUM_TARGET_API
449          << (SLANG_ICS_TARGET_API - 1);
450      valid = false;
451    }
452  }
453  return valid;
454}
455
456RSExportForEach *RSExportForEach::Create(RSContext *Context,
457                                         const clang::FunctionDecl *FD) {
458  slangAssert(Context && FD);
459  llvm::StringRef Name = FD->getName();
460  RSExportForEach *FE;
461
462  slangAssert(!Name.empty() && "Function must have a name");
463
464  FE = new RSExportForEach(Context, Name);
465
466  if (!FE->validateAndConstructParams(Context, FD)) {
467    return nullptr;
468  }
469
470  clang::ASTContext &Ctx = Context->getASTContext();
471
472  std::string Id = CreateDummyName("helper_foreach_param", FE->getName());
473
474  // Extract the usrData parameter (if we have one)
475  if (FE->mUsrData) {
476    const clang::ParmVarDecl *PVD = FE->mUsrData;
477    clang::QualType QT = PVD->getType().getCanonicalType();
478    slangAssert(QT->isPointerType() &&
479                QT->getPointeeType().isConstQualified());
480
481    const clang::ASTContext &C = Context->getASTContext();
482    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
483        C.VoidTy) {
484      // In the case of using const void*, we can't reflect an appopriate
485      // Java type, so we fall back to just reflecting the ain/aout parameters
486      FE->mUsrData = nullptr;
487    } else {
488      clang::RecordDecl *RD =
489          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
490                                    Ctx.getTranslationUnitDecl(),
491                                    clang::SourceLocation(),
492                                    clang::SourceLocation(),
493                                    &Ctx.Idents.get(Id));
494
495      clang::FieldDecl *FD =
496          clang::FieldDecl::Create(Ctx,
497                                   RD,
498                                   clang::SourceLocation(),
499                                   clang::SourceLocation(),
500                                   PVD->getIdentifier(),
501                                   QT->getPointeeType(),
502                                   nullptr,
503                                   /* BitWidth = */ nullptr,
504                                   /* Mutable = */ false,
505                                   /* HasInit = */ clang::ICIS_NoInit);
506      RD->addDecl(FD);
507      RD->completeDefinition();
508
509      // Create an export type iff we have a valid usrData type
510      clang::QualType T = Ctx.getTagDeclType(RD);
511      slangAssert(!T.isNull());
512
513      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
514
515      slangAssert(ET && "Failed to export a kernel");
516
517      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
518                  "Parameter packet must be a record");
519
520      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
521    }
522  }
523
524  // Construct type information about inputs and outputs. Return null when
525  // there is an error exporting types.
526
527  bool TypeExportError = false;
528
529  if (FE->hasIns()) {
530    for (InIter BI = FE->mIns.begin(), EI = FE->mIns.end(); BI != EI; BI++) {
531      const clang::Type *T = (*BI)->getType().getCanonicalType().getTypePtr();
532      RSExportType *InExportType = RSExportType::Create(Context, T);
533
534      // It is not an error if we don't export an input type for legacy
535      // kernels. This can happen in the case of a void pointer.
536      if (FE->mIsKernelStyle && !InExportType) {
537        TypeExportError = true;
538      }
539
540      FE->mInTypes.push_back(InExportType);
541    }
542  }
543
544  if (FE->mIsKernelStyle && FE->mHasReturnType) {
545    const clang::Type *ReturnType = FE->mResultType.getTypePtr();
546    FE->mOutType = RSExportType::Create(Context, ReturnType);
547    TypeExportError |= !FE->mOutType;
548  } else if (FE->mOut) {
549    const clang::Type *OutType =
550        FE->mOut->getType().getCanonicalType().getTypePtr();
551    FE->mOutType = RSExportType::Create(Context, OutType);
552    // It is not an error if we don't export an output type.
553    // This can happen in the case of a void pointer.
554  }
555
556  if (TypeExportError) {
557    slangAssert(Context->getDiagnostics()->hasErrorOccurred() &&
558                "Error exporting type but no diagnostic message issued!");
559    return nullptr;
560  }
561
562  return FE;
563}
564
565RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
566  slangAssert(Context);
567  llvm::StringRef Name = "root";
568  RSExportForEach *FE = new RSExportForEach(Context, Name);
569  FE->mDummyRoot = true;
570  return FE;
571}
572
573bool RSExportForEach::isRSForEachFunc(unsigned int targetAPI,
574                                      const clang::FunctionDecl *FD) {
575  slangAssert(FD);
576
577  // Anything tagged as a kernel is definitely used with ForEach.
578  if (FD->hasAttr<clang::KernelAttr>()) {
579    return true;
580  }
581
582  if (RSSpecialFunc::isGraphicsRootRSFunc(targetAPI, FD)) {
583    return false;
584  }
585
586  // Check if first parameter is a pointer (which is required for ForEach).
587  unsigned int numParams = FD->getNumParams();
588
589  if (numParams > 0) {
590    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
591    clang::QualType QT = PVD->getType().getCanonicalType();
592
593    if (QT->isPointerType()) {
594      return true;
595    }
596
597    // Any non-graphics root() is automatically a ForEach candidate.
598    // At this point, however, we know that it is not going to be a valid
599    // compute root() function (due to not having a pointer parameter). We
600    // still want to return true here, so that we can issue appropriate
601    // diagnostics.
602    if (isRootRSFunc(FD)) {
603      return true;
604    }
605  }
606
607  return false;
608}
609
610}  // namespace slang
611