slang_rs_export_foreach.cpp revision eae0b7ad0195360b0afc37d51553f2917f1aa365
1/*
2 * Copyright 2011-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Attr.h"
23#include "clang/AST/Decl.h"
24#include "clang/AST/TypeLoc.h"
25
26#include "llvm/IR/DerivedTypes.h"
27
28#include "bcinfo/MetadataExtractor.h"
29
30#include "slang_assert.h"
31#include "slang_rs_context.h"
32#include "slang_rs_export_type.h"
33#include "slang_rs_special_func.h"
34#include "slang_version.h"
35
36namespace {
37
38const size_t RS_KERNEL_INPUT_LIMIT = 8; // see frameworks/base/libs/rs/cpu_ref/rsCpuCoreRuntime.h
39
40enum SpecialParameterKind {
41  SPK_INT,  // 'int' or 'unsigned int'
42  SPK_CTXT, // rs_kernel_context
43};
44
45struct SpecialParameter {
46  const char *name;
47  bcinfo::MetadataSignatureBitval bitval;
48  SpecialParameterKind kind;
49  SlangTargetAPI minAPI;
50};
51
52// Table entries are in the order parameters must occur in a kernel parameter list.
53const SpecialParameter specialParameterTable[] = {
54  { "ctxt", bcinfo::MD_SIG_Ctxt, SPK_CTXT, SLANG_23_TARGET_API },
55  { "x", bcinfo::MD_SIG_X, SPK_INT, SLANG_MINIMUM_TARGET_API },
56  { "y", bcinfo::MD_SIG_Y, SPK_INT, SLANG_MINIMUM_TARGET_API },
57  { "z", bcinfo::MD_SIG_Z, SPK_INT, SLANG_23_TARGET_API },
58  { nullptr, bcinfo::MD_SIG_None, SPK_INT, SLANG_MINIMUM_TARGET_API }, // marks end of table
59};
60
61// If the specified name matches the name of an entry in
62// specialParameterTable, return the corresponding table index;
63// otherwise return -1.
64int lookupSpecialParameter(const llvm::StringRef name) {
65  for (int i = 0; specialParameterTable[i].name != nullptr; ++i)
66    if (name.equals(specialParameterTable[i].name))
67      return i;
68  return -1;
69}
70
71// Return a comma-separated list of names in specialParameterTable
72// that are available at the specified API level.
73std::string listSpecialParameters(unsigned int api) {
74  std::string ret;
75  bool first = true;
76  for (int i = 0; specialParameterTable[i].name != nullptr; ++i) {
77    if (specialParameterTable[i].minAPI > api)
78      continue;
79    if (first)
80      first = false;
81    else
82      ret += ", ";
83    ret += "'";
84    ret += specialParameterTable[i].name;
85    ret += "'";
86  }
87  return ret;
88}
89
90bool isRootRSFunc(const clang::FunctionDecl *FD) {
91  if (!FD) {
92    return false;
93  }
94  return FD->getName().equals("root");
95}
96
97} // end anonymous namespace
98
99namespace slang {
100
101// This function takes care of additional validation and construction of
102// parameters related to forEach_* reflection.
103bool RSExportForEach::validateAndConstructParams(
104    RSContext *Context, const clang::FunctionDecl *FD) {
105  slangAssert(Context && FD);
106  bool valid = true;
107
108  numParams = FD->getNumParams();
109
110  if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
111    // Before JellyBean, we allowed only one kernel per file.  It must be called "root".
112    if (!isRootRSFunc(FD)) {
113      Context->ReportError(FD->getLocation(),
114                           "Non-root compute kernel %0() is "
115                           "not supported in SDK levels %1-%2")
116          << FD->getName() << SLANG_MINIMUM_TARGET_API
117          << (SLANG_JB_TARGET_API - 1);
118      return false;
119    }
120  }
121
122  mResultType = FD->getReturnType().getCanonicalType();
123  // Compute kernel functions are defined differently when the
124  // "__attribute__((kernel))" is set.
125  if (FD->hasAttr<clang::KernelAttr>()) {
126    valid |= validateAndConstructKernelParams(Context, FD);
127  } else {
128    valid |= validateAndConstructOldStyleParams(Context, FD);
129  }
130
131  valid |= setSignatureMetadata(Context, FD);
132  return valid;
133}
134
135bool RSExportForEach::validateAndConstructOldStyleParams(
136    RSContext *Context, const clang::FunctionDecl *FD) {
137  slangAssert(Context && FD);
138  // If numParams is 0, we already marked this as a graphics root().
139  slangAssert(numParams > 0);
140
141  bool valid = true;
142
143  // Compute kernel functions of this style are required to return a void type.
144  clang::ASTContext &C = Context->getASTContext();
145  if (mResultType != C.VoidTy) {
146    Context->ReportError(FD->getLocation(),
147                         "Compute kernel %0() is required to return a "
148                         "void type")
149        << FD->getName();
150    valid = false;
151  }
152
153  // Validate remaining parameter types
154
155  size_t IndexOfFirstSpecialParameter = numParams;
156  valid |= validateSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
157
158  // Validate the non-special parameters, which should all be found before the
159  // first special parameter.
160  for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
161    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
162    clang::QualType QT = PVD->getType().getCanonicalType();
163
164    if (!QT->isPointerType()) {
165      Context->ReportError(PVD->getLocation(),
166                           "Compute kernel %0() cannot have non-pointer "
167                           "parameters besides (%1). Parameter '%2' is "
168                           "of type: '%3'")
169          << FD->getName() << listSpecialParameters(Context->getTargetAPI())
170          << PVD->getName() << PVD->getType().getAsString();
171      valid = false;
172      continue;
173    }
174
175    // The only non-const pointer should be out.
176    if (!QT->getPointeeType().isConstQualified()) {
177      if (mOut == nullptr) {
178        mOut = PVD;
179      } else {
180        Context->ReportError(PVD->getLocation(),
181                             "Compute kernel %0() can only have one non-const "
182                             "pointer parameter. Parameters '%1' and '%2' are "
183                             "both non-const.")
184            << FD->getName() << mOut->getName() << PVD->getName();
185        valid = false;
186      }
187    } else {
188      if (mIns.empty() && mOut == nullptr) {
189        mIns.push_back(PVD);
190      } else if (mUsrData == nullptr) {
191        mUsrData = PVD;
192      } else {
193        Context->ReportError(
194            PVD->getLocation(),
195            "Unexpected parameter '%0' for compute kernel %1()")
196            << PVD->getName() << FD->getName();
197        valid = false;
198      }
199    }
200  }
201
202  if (mIns.empty() && !mOut) {
203    Context->ReportError(FD->getLocation(),
204                         "Compute kernel %0() must have at least one "
205                         "parameter for in or out")
206        << FD->getName();
207    valid = false;
208  }
209
210  return valid;
211}
212
213bool RSExportForEach::validateAndConstructKernelParams(
214    RSContext *Context, const clang::FunctionDecl *FD) {
215  slangAssert(Context && FD);
216  bool valid = true;
217  clang::ASTContext &C = Context->getASTContext();
218
219  if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
220    Context->ReportError(FD->getLocation(),
221                         "Compute kernel %0() targeting SDK levels "
222                         "%1-%2 may not use pass-by-value with "
223                         "__attribute__((kernel))")
224        << FD->getName() << SLANG_MINIMUM_TARGET_API
225        << (SLANG_JB_MR1_TARGET_API - 1);
226    return false;
227  }
228
229  // Denote that we are indeed a pass-by-value kernel.
230  mIsKernelStyle = true;
231  mHasReturnType = (mResultType != C.VoidTy);
232
233  if (mResultType->isPointerType()) {
234    Context->ReportError(
235        FD->getTypeSpecStartLoc(),
236        "Compute kernel %0() cannot return a pointer type: '%1'")
237        << FD->getName() << mResultType.getAsString();
238    valid = false;
239  }
240
241  // Validate remaining parameter types
242
243  size_t IndexOfFirstSpecialParameter = numParams;
244  valid |= validateSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
245
246  // Validate the non-special parameters, which should all be found before the
247  // first special.
248  for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
249    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
250
251    /*
252     * FIXME: Change this to a test against an actual API version when the
253     *        multi-input feature is officially supported.
254     */
255    if (Context->getTargetAPI() == SLANG_DEVELOPMENT_TARGET_API || i == 0) {
256      if (i >= RS_KERNEL_INPUT_LIMIT) {
257        Context->ReportError(PVD->getLocation(),
258                             "Invalid parameter '%0' for compute kernel %1(). "
259                             "Kernels targeting SDK levels %2-%3 may not use "
260                             "more than %4 input parameters.") << PVD->getName() <<
261                             FD->getName() << SLANG_MINIMUM_TARGET_API <<
262                             SLANG_MAXIMUM_TARGET_API << int(RS_KERNEL_INPUT_LIMIT);
263
264      } else {
265        mIns.push_back(PVD);
266      }
267    } else {
268      Context->ReportError(PVD->getLocation(),
269                           "Invalid parameter '%0' for compute kernel %1(). "
270                           "Kernels targeting SDK levels %2-%3 may not use "
271                           "multiple input parameters.") << PVD->getName() <<
272                           FD->getName() << SLANG_MINIMUM_TARGET_API <<
273                           SLANG_MAXIMUM_TARGET_API;
274      valid = false;
275    }
276    clang::QualType QT = PVD->getType().getCanonicalType();
277    if (QT->isPointerType()) {
278      Context->ReportError(PVD->getLocation(),
279                           "Compute kernel %0() cannot have "
280                           "parameter '%1' of pointer type: '%2'")
281          << FD->getName() << PVD->getName() << PVD->getType().getAsString();
282      valid = false;
283    }
284  }
285
286  // Check that we have at least one allocation to use for dimensions.
287  if (valid && mIns.empty() && !mHasReturnType && Context->getTargetAPI() < SLANG_23_TARGET_API) {
288    Context->ReportError(FD->getLocation(),
289                         "Compute kernel %0() targeting SDK levels "
290                         "%1-%2 must have at least one "
291                         "input parameter or a non-void return "
292                         "type")
293        << FD->getName() << SLANG_MINIMUM_TARGET_API
294        << (SLANG_23_TARGET_API - 1);
295    valid = false;
296  }
297
298  return valid;
299}
300
// Scans FD's parameter list for the optional special parameters listed in
// specialParameterTable (ctxt, x, y, z) and validates them.
//
// For each special parameter found, this:
//  - reports an error if the target API is below the entry's minAPI;
//  - ORs the entry's bit into mSpecialParameterSignatureMetadata;
//  - reports an error if it appears after a special parameter that the
//    table orders later;
//  - requires all SPK_INT parameters to share the same (non-canonical)
//    type, and checks each parameter's type against its kind.
// Any ordinary parameter that appears after a special parameter is an
// error.
//
// Context: used to report diagnostics and to query the target API.
// FD: the kernel being validated.
// IndexOfFirstSpecialParameter: out-param (must be non-null).  Set to the
//   index of the first special parameter, or FD->getNumParams() if none are
//   found.
// Returns true if no error was reported.
// Precondition: mSpecialParameterSignatureMetadata is 0 on entry (asserted).
bool RSExportForEach::validateSpecialParameters(
    RSContext *Context, const clang::FunctionDecl *FD,
    size_t *IndexOfFirstSpecialParameter) {
  slangAssert(IndexOfFirstSpecialParameter != nullptr);
  slangAssert(mSpecialParameterSignatureMetadata == 0);
  clang::ASTContext &C = Context->getASTContext();

  // Find all special parameters if present.
  int LastSpecialParameterIdx = -1;     // index into specialParameterTable
  int FirstIntSpecialParameterIdx = -1; // index into specialParameterTable
  clang::QualType FirstIntSpecialParameterType;
  size_t NumParams = FD->getNumParams();
  // NumParams acts as the "no special parameter seen yet" sentinel.
  *IndexOfFirstSpecialParameter = NumParams;
  bool valid = true;
  for (size_t i = 0; i < NumParams; i++) {
    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
    llvm::StringRef ParamName = PVD->getName();
    int SpecialParameterIdx = lookupSpecialParameter(ParamName);
    if (SpecialParameterIdx >= 0) {
      const SpecialParameter &SP = specialParameterTable[SpecialParameterIdx];
      // We won't be invoked if two parameters of the same name are present.
      slangAssert(!(mSpecialParameterSignatureMetadata & SP.bitval));

      if (Context->getTargetAPI() < SP.minAPI) {
        Context->ReportError(PVD->getLocation(),
                             "Compute kernel %0() targeting SDK levels "
                             "%1-%2 may not use parameter '%3'.")
            << FD->getName()
            << SLANG_MINIMUM_TARGET_API
            << (SP.minAPI - 1)
            << SP.name;
        valid = false;
      }

      mSpecialParameterSignatureMetadata |= SP.bitval;
      // Table order is the required declaration order: a smaller table index
      // than the previous special parameter means the two are swapped.
      if (SpecialParameterIdx < LastSpecialParameterIdx) {
        Context->ReportError(PVD->getLocation(),
                             "In compute kernel %0(), parameter '%1' must "
                             "be defined before parameter '%2'.")
            << FD->getName()
            << SP.name
            << specialParameterTable[LastSpecialParameterIdx].name;
        valid = false;
      }
      LastSpecialParameterIdx = SpecialParameterIdx;

      // Ensure that all SPK_INT special parameters have the same type.
      // The first SPK_INT parameter's declared type is the reference.
      if (SP.kind == SPK_INT) {
        clang::QualType SpecialParameterType = PVD->getType();
        if (FirstIntSpecialParameterIdx >= 0) {
          if (SpecialParameterType != FirstIntSpecialParameterType) {
            Context->ReportError(PVD->getLocation(),
                                 "Parameters '%0' and '%1' must be of the same type. "
                                 "'%0' is of type '%2' while '%1' is of type '%3'.")
                << specialParameterTable[FirstIntSpecialParameterIdx].name
                << SP.name
                << FirstIntSpecialParameterType.getAsString()
                << SpecialParameterType.getAsString();
            valid = false;
          }
        } else {
          FirstIntSpecialParameterIdx = SpecialParameterIdx;
          FirstIntSpecialParameterType = SpecialParameterType;
        }
      }
    } else {
      // It's not a special parameter.
      // Ordinary parameters must all precede the special ones; having seen a
      // special parameter already makes this one misplaced.
      if (*IndexOfFirstSpecialParameter < NumParams) {
        Context->ReportError(PVD->getLocation(),
                             "In compute kernel %0(), parameter '%1' cannot "
                             "appear after any of the (%2) parameters.")
            << FD->getName() << ParamName << listSpecialParameters(Context->getTargetAPI());
        valid = false;
      }
      continue;
    }
    // Validate the data type of the special parameter.
    switch (specialParameterTable[SpecialParameterIdx].kind) {
      case SPK_INT: {
        // Accept 'int' or 'unsigned int' after stripping typedefs (canonical
        // type) and top-level qualifiers.
        clang::QualType QT = PVD->getType().getCanonicalType();
        clang::QualType UT = QT.getUnqualifiedType();
        if (UT != C.UnsignedIntTy && UT != C.IntTy) {
          Context->ReportError(PVD->getLocation(),
                               "Parameter '%0' must be of type 'int' or "
                               "'unsigned int'. It is of type '%1'.")
              << ParamName << PVD->getType().getAsString();
          valid = false;
        }
        break;
      }
      case SPK_CTXT: {
        // The check compares the printed canonical, unqualified type against
        // ExpectedTypeNameMatch.  The diagnostic shows the user-facing
        // typedef name ExpectedTypeNamePrint instead.
        static const char ExpectedTypeNameMatch[] = "const struct rs_kernel_context_t *";
        static const char ExpectedTypeNamePrint[] = "rs_kernel_context";
        clang::QualType QT = PVD->getType().getCanonicalType();
        clang::QualType UT = QT.getUnqualifiedType();
        if (UT.getAsString() != ExpectedTypeNameMatch) {
          Context->ReportError(PVD->getLocation(),
                               "Parameter '%0' must be of type '%1'. "
                               "It is of type '%2'.")
              << ParamName << ExpectedTypeNamePrint << PVD->getType().getAsString();
          valid = false;
        }
        break;
      }
      default:
        slangAssert(!"Unexpected special parameter type");
    }
    // If this is the first time we find a special parameter, save it.
    if (*IndexOfFirstSpecialParameter >= NumParams) {
      *IndexOfFirstSpecialParameter = i;
    }
  }
  return valid;
}
418
419bool RSExportForEach::setSignatureMetadata(RSContext *Context,
420                                           const clang::FunctionDecl *FD) {
421  mSignatureMetadata = 0;
422  bool valid = true;
423
424  if (mIsKernelStyle) {
425    slangAssert(mOut == nullptr);
426    slangAssert(mUsrData == nullptr);
427  } else {
428    slangAssert(!mHasReturnType);
429  }
430
431  // Set up the bitwise metadata encoding for runtime argument passing.
432  const bool HasOut = mOut || mHasReturnType;
433  mSignatureMetadata |= (hasIns() ?       bcinfo::MD_SIG_In     : 0);
434  mSignatureMetadata |= (HasOut ?         bcinfo::MD_SIG_Out    : 0);
435  mSignatureMetadata |= (mUsrData ?       bcinfo::MD_SIG_Usr    : 0);
436  mSignatureMetadata |= (mIsKernelStyle ? bcinfo::MD_SIG_Kernel : 0);  // pass-by-value
437  mSignatureMetadata |= mSpecialParameterSignatureMetadata;
438
439  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
440    // APIs before ICS cannot skip between parameters. It is ok, however, for
441    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
442    if (mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
443                               bcinfo::MD_SIG_X | bcinfo::MD_SIG_Y) &&
444        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
445                               bcinfo::MD_SIG_X) &&
446        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr) &&
447        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out) &&
448        mSignatureMetadata != (bcinfo::MD_SIG_In)) {
449      Context->ReportError(FD->getLocation(),
450                           "Compute kernel %0() targeting SDK levels "
451                           "%1-%2 may not skip parameters")
452          << FD->getName() << SLANG_MINIMUM_TARGET_API
453          << (SLANG_ICS_TARGET_API - 1);
454      valid = false;
455    }
456  }
457  return valid;
458}
459
460RSExportForEach *RSExportForEach::Create(RSContext *Context,
461                                         const clang::FunctionDecl *FD) {
462  slangAssert(Context && FD);
463  llvm::StringRef Name = FD->getName();
464  RSExportForEach *FE;
465
466  slangAssert(!Name.empty() && "Function must have a name");
467
468  FE = new RSExportForEach(Context, Name);
469
470  if (!FE->validateAndConstructParams(Context, FD)) {
471    return nullptr;
472  }
473
474  clang::ASTContext &Ctx = Context->getASTContext();
475
476  std::string Id = CreateDummyName("helper_foreach_param", FE->getName());
477
478  // Extract the usrData parameter (if we have one)
479  if (FE->mUsrData) {
480    const clang::ParmVarDecl *PVD = FE->mUsrData;
481    clang::QualType QT = PVD->getType().getCanonicalType();
482    slangAssert(QT->isPointerType() &&
483                QT->getPointeeType().isConstQualified());
484
485    const clang::ASTContext &C = Context->getASTContext();
486    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
487        C.VoidTy) {
488      // In the case of using const void*, we can't reflect an appopriate
489      // Java type, so we fall back to just reflecting the ain/aout parameters
490      FE->mUsrData = nullptr;
491    } else {
492      clang::RecordDecl *RD =
493          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
494                                    Ctx.getTranslationUnitDecl(),
495                                    clang::SourceLocation(),
496                                    clang::SourceLocation(),
497                                    &Ctx.Idents.get(Id));
498
499      clang::FieldDecl *FD =
500          clang::FieldDecl::Create(Ctx,
501                                   RD,
502                                   clang::SourceLocation(),
503                                   clang::SourceLocation(),
504                                   PVD->getIdentifier(),
505                                   QT->getPointeeType(),
506                                   nullptr,
507                                   /* BitWidth = */ nullptr,
508                                   /* Mutable = */ false,
509                                   /* HasInit = */ clang::ICIS_NoInit);
510      RD->addDecl(FD);
511      RD->completeDefinition();
512
513      // Create an export type iff we have a valid usrData type
514      clang::QualType T = Ctx.getTagDeclType(RD);
515      slangAssert(!T.isNull());
516
517      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
518
519      slangAssert(ET && "Failed to export a kernel");
520
521      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
522                  "Parameter packet must be a record");
523
524      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
525    }
526  }
527
528  if (FE->hasIns()) {
529
530    for (InIter BI = FE->mIns.begin(), EI = FE->mIns.end(); BI != EI; BI++) {
531      const clang::Type *T = (*BI)->getType().getCanonicalType().getTypePtr();
532      RSExportType *InExportType = RSExportType::Create(Context, T);
533
534      if (FE->mIsKernelStyle) {
535        slangAssert(InExportType != nullptr);
536      }
537
538      FE->mInTypes.push_back(InExportType);
539    }
540  }
541
542  if (FE->mIsKernelStyle && FE->mHasReturnType) {
543    const clang::Type *T = FE->mResultType.getTypePtr();
544    FE->mOutType = RSExportType::Create(Context, T);
545    slangAssert(FE->mOutType);
546  } else if (FE->mOut) {
547    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
548    FE->mOutType = RSExportType::Create(Context, T);
549  }
550
551  return FE;
552}
553
554RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
555  slangAssert(Context);
556  llvm::StringRef Name = "root";
557  RSExportForEach *FE = new RSExportForEach(Context, Name);
558  FE->mDummyRoot = true;
559  return FE;
560}
561
562bool RSExportForEach::isRSForEachFunc(unsigned int targetAPI,
563                                      const clang::FunctionDecl *FD) {
564  slangAssert(FD);
565
566  // Anything tagged as a kernel is definitely used with ForEach.
567  if (FD->hasAttr<clang::KernelAttr>()) {
568    return true;
569  }
570
571  if (RSSpecialFunc::isGraphicsRootRSFunc(targetAPI, FD)) {
572    return false;
573  }
574
575  // Check if first parameter is a pointer (which is required for ForEach).
576  unsigned int numParams = FD->getNumParams();
577
578  if (numParams > 0) {
579    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
580    clang::QualType QT = PVD->getType().getCanonicalType();
581
582    if (QT->isPointerType()) {
583      return true;
584    }
585
586    // Any non-graphics root() is automatically a ForEach candidate.
587    // At this point, however, we know that it is not going to be a valid
588    // compute root() function (due to not having a pointer parameter). We
589    // still want to return true here, so that we can issue appropriate
590    // diagnostics.
591    if (isRootRSFunc(FD)) {
592      return true;
593    }
594  }
595
596  return false;
597}
598
599}  // namespace slang
600