slang_rs_export_foreach.cpp revision cd7d3128f41a59692e5c59a2b81969616579aae4
1/*
2 * Copyright 2011-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Attr.h"
23#include "clang/AST/Decl.h"
24#include "clang/AST/TypeLoc.h"
25
26#include "llvm/IR/DerivedTypes.h"
27
28#include "bcinfo/MetadataExtractor.h"
29
30#include "slang_assert.h"
31#include "slang_rs_context.h"
32#include "slang_rs_export_type.h"
33#include "slang_version.h"
34
35namespace {
36
37const size_t RS_KERNEL_INPUT_LIMIT = 8; // see frameworks/base/libs/rs/cpu_ref/rsCpuCoreRuntime.h
38
39enum SpecialParameterKind {
40  SPK_INT,  // 'int' or 'unsigned int'
41  SPK_CTXT, // rs_kernel_context
42};
43
44struct SpecialParameter {
45  const char *name;
46  bcinfo::MetadataSignatureBitval bitval;
47  SpecialParameterKind kind;
48  SlangTargetAPI minAPI;
49};
50
51// Table entries are in the order parameters must occur in a kernel parameter list.
52const SpecialParameter specialParameterTable[] = {
53  { "ctxt", bcinfo::MD_SIG_Ctxt, SPK_CTXT, SLANG_23_TARGET_API },
54  { "x", bcinfo::MD_SIG_X, SPK_INT, SLANG_MINIMUM_TARGET_API },
55  { "y", bcinfo::MD_SIG_Y, SPK_INT, SLANG_MINIMUM_TARGET_API },
56  { "z", bcinfo::MD_SIG_Z, SPK_INT, SLANG_23_TARGET_API },
57  { nullptr, bcinfo::MD_SIG_None, SPK_INT, SLANG_MINIMUM_TARGET_API }, // marks end of table
58};
59
60// If the specified name matches the name of an entry in
61// specialParameterTable, return the corresponding table index;
62// otherwise return -1.
63int lookupSpecialParameter(const llvm::StringRef name) {
64  for (int i = 0; specialParameterTable[i].name != nullptr; ++i)
65    if (name.equals(specialParameterTable[i].name))
66      return i;
67  return -1;
68}
69
70// Return a comma-separated list of names in specialParameterTable
71// that are available at the specified API level.
72std::string listSpecialParameters(unsigned int api) {
73  std::string ret;
74  bool first = true;
75  for (int i = 0; specialParameterTable[i].name != nullptr; ++i) {
76    if (specialParameterTable[i].minAPI > api)
77      continue;
78    if (first)
79      first = false;
80    else
81      ret += ", ";
82    ret += "'";
83    ret += specialParameterTable[i].name;
84    ret += "'";
85  }
86  return ret;
87}
88
89}
90
91namespace slang {
92
93// This function takes care of additional validation and construction of
94// parameters related to forEach_* reflection.
95bool RSExportForEach::validateAndConstructParams(
96    RSContext *Context, const clang::FunctionDecl *FD) {
97  slangAssert(Context && FD);
98  bool valid = true;
99
100  numParams = FD->getNumParams();
101
102  if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
103    // Before JellyBean, we allowed only one kernel per file.  It must be called "root".
104    if (!isRootRSFunc(FD)) {
105      Context->ReportError(FD->getLocation(),
106                           "Non-root compute kernel %0() is "
107                           "not supported in SDK levels %1-%2")
108          << FD->getName() << SLANG_MINIMUM_TARGET_API
109          << (SLANG_JB_TARGET_API - 1);
110      return false;
111    }
112  }
113
114  mResultType = FD->getReturnType().getCanonicalType();
115  // Compute kernel functions are defined differently when the
116  // "__attribute__((kernel))" is set.
117  if (FD->hasAttr<clang::KernelAttr>()) {
118    valid |= validateAndConstructKernelParams(Context, FD);
119  } else {
120    valid |= validateAndConstructOldStyleParams(Context, FD);
121  }
122
123  valid |= setSignatureMetadata(Context, FD);
124  return valid;
125}
126
127bool RSExportForEach::validateAndConstructOldStyleParams(
128    RSContext *Context, const clang::FunctionDecl *FD) {
129  slangAssert(Context && FD);
130  // If numParams is 0, we already marked this as a graphics root().
131  slangAssert(numParams > 0);
132
133  bool valid = true;
134
135  // Compute kernel functions of this style are required to return a void type.
136  clang::ASTContext &C = Context->getASTContext();
137  if (mResultType != C.VoidTy) {
138    Context->ReportError(FD->getLocation(),
139                         "Compute kernel %0() is required to return a "
140                         "void type")
141        << FD->getName();
142    valid = false;
143  }
144
145  // Validate remaining parameter types
146
147  size_t IndexOfFirstSpecialParameter = numParams;
148  valid |= validateSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
149
150  // Validate the non-special parameters, which should all be found before the
151  // first special parameter.
152  for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
153    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
154    clang::QualType QT = PVD->getType().getCanonicalType();
155
156    if (!QT->isPointerType()) {
157      Context->ReportError(PVD->getLocation(),
158                           "Compute kernel %0() cannot have non-pointer "
159                           "parameters besides (%1). Parameter '%2' is "
160                           "of type: '%3'")
161          << FD->getName() << listSpecialParameters(Context->getTargetAPI())
162          << PVD->getName() << PVD->getType().getAsString();
163      valid = false;
164      continue;
165    }
166
167    // The only non-const pointer should be out.
168    if (!QT->getPointeeType().isConstQualified()) {
169      if (mOut == nullptr) {
170        mOut = PVD;
171      } else {
172        Context->ReportError(PVD->getLocation(),
173                             "Compute kernel %0() can only have one non-const "
174                             "pointer parameter. Parameters '%1' and '%2' are "
175                             "both non-const.")
176            << FD->getName() << mOut->getName() << PVD->getName();
177        valid = false;
178      }
179    } else {
180      if (mIns.empty() && mOut == nullptr) {
181        mIns.push_back(PVD);
182      } else if (mUsrData == nullptr) {
183        mUsrData = PVD;
184      } else {
185        Context->ReportError(
186            PVD->getLocation(),
187            "Unexpected parameter '%0' for compute kernel %1()")
188            << PVD->getName() << FD->getName();
189        valid = false;
190      }
191    }
192  }
193
194  if (mIns.empty() && !mOut) {
195    Context->ReportError(FD->getLocation(),
196                         "Compute kernel %0() must have at least one "
197                         "parameter for in or out")
198        << FD->getName();
199    valid = false;
200  }
201
202  return valid;
203}
204
205bool RSExportForEach::validateAndConstructKernelParams(
206    RSContext *Context, const clang::FunctionDecl *FD) {
207  slangAssert(Context && FD);
208  bool valid = true;
209  clang::ASTContext &C = Context->getASTContext();
210
211  if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
212    Context->ReportError(FD->getLocation(),
213                         "Compute kernel %0() targeting SDK levels "
214                         "%1-%2 may not use pass-by-value with "
215                         "__attribute__((kernel))")
216        << FD->getName() << SLANG_MINIMUM_TARGET_API
217        << (SLANG_JB_MR1_TARGET_API - 1);
218    return false;
219  }
220
221  // Denote that we are indeed a pass-by-value kernel.
222  mIsKernelStyle = true;
223  mHasReturnType = (mResultType != C.VoidTy);
224
225  if (mResultType->isPointerType()) {
226    Context->ReportError(
227        FD->getTypeSpecStartLoc(),
228        "Compute kernel %0() cannot return a pointer type: '%1'")
229        << FD->getName() << mResultType.getAsString();
230    valid = false;
231  }
232
233  // Validate remaining parameter types
234
235  size_t IndexOfFirstSpecialParameter = numParams;
236  valid |= validateSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
237
238  // Validate the non-special parameters, which should all be found before the
239  // first special.
240  for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
241    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
242
243    /*
244     * FIXME: Change this to a test against an actual API version when the
245     *        multi-input feature is officially supported.
246     */
247    if (Context->getTargetAPI() == SLANG_DEVELOPMENT_TARGET_API || i == 0) {
248      if (i >= RS_KERNEL_INPUT_LIMIT) {
249        Context->ReportError(PVD->getLocation(),
250                             "Invalid parameter '%0' for compute kernel %1(). "
251                             "Kernels targeting SDK levels %2-%3 may not use "
252                             "more than %4 input parameters.") << PVD->getName() <<
253                             FD->getName() << SLANG_MINIMUM_TARGET_API <<
254                             SLANG_MAXIMUM_TARGET_API << int(RS_KERNEL_INPUT_LIMIT);
255
256      } else {
257        mIns.push_back(PVD);
258      }
259    } else {
260      Context->ReportError(PVD->getLocation(),
261                           "Invalid parameter '%0' for compute kernel %1(). "
262                           "Kernels targeting SDK levels %2-%3 may not use "
263                           "multiple input parameters.") << PVD->getName() <<
264                           FD->getName() << SLANG_MINIMUM_TARGET_API <<
265                           SLANG_MAXIMUM_TARGET_API;
266      valid = false;
267    }
268    clang::QualType QT = PVD->getType().getCanonicalType();
269    if (QT->isPointerType()) {
270      Context->ReportError(PVD->getLocation(),
271                           "Compute kernel %0() cannot have "
272                           "parameter '%1' of pointer type: '%2'")
273          << FD->getName() << PVD->getName() << PVD->getType().getAsString();
274      valid = false;
275    }
276  }
277
278  // Check that we have at least one allocation to use for dimensions.
279  if (valid && mIns.empty() && !mHasReturnType) {
280    Context->ReportError(FD->getLocation(),
281                         "Compute kernel %0() must have at least one "
282                         "input parameter or a non-void return "
283                         "type")
284        << FD->getName();
285    valid = false;
286  }
287
288  return valid;
289}
290
291// Search for the optional special parameters.  Returns true if valid.   Also
292// sets *IndexOfFirstSpecialParameter to the index of the first special parameter, or
293// FD->getNumParams() if none are found.
294bool RSExportForEach::validateSpecialParameters(
295    RSContext *Context, const clang::FunctionDecl *FD,
296    size_t *IndexOfFirstSpecialParameter) {
297  slangAssert(IndexOfFirstSpecialParameter != nullptr);
298  slangAssert(mSpecialParameterSignatureMetadata == 0);
299  clang::ASTContext &C = Context->getASTContext();
300
301  // Find all special parameters if present.
302  int LastSpecialParameterIdx = -1;     // index into specialParameterTable
303  int FirstIntSpecialParameterIdx = -1; // index into specialParameterTable
304  clang::QualType FirstIntSpecialParameterType;
305  size_t NumParams = FD->getNumParams();
306  *IndexOfFirstSpecialParameter = NumParams;
307  bool valid = true;
308  for (size_t i = 0; i < NumParams; i++) {
309    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
310    llvm::StringRef ParamName = PVD->getName();
311    int SpecialParameterIdx = lookupSpecialParameter(ParamName);
312    if (SpecialParameterIdx >= 0) {
313      const SpecialParameter &SP = specialParameterTable[SpecialParameterIdx];
314      // We won't be invoked if two parameters of the same name are present.
315      slangAssert(!(mSpecialParameterSignatureMetadata & SP.bitval));
316
317      if (Context->getTargetAPI() < SP.minAPI) {
318        Context->ReportError(PVD->getLocation(),
319                             "Compute kernel %0() targeting SDK levels "
320                             "%1-%2 may not use parameter '%3'.")
321            << FD->getName()
322            << SLANG_MINIMUM_TARGET_API
323            << (SP.minAPI - 1)
324            << SP.name;
325        valid = false;
326      }
327
328      mSpecialParameterSignatureMetadata |= SP.bitval;
329      if (SpecialParameterIdx < LastSpecialParameterIdx) {
330        Context->ReportError(PVD->getLocation(),
331                             "In compute kernel %0(), parameter '%1' must "
332                             "be defined before parameter '%2'.")
333            << FD->getName()
334            << SP.name
335            << specialParameterTable[LastSpecialParameterIdx].name;
336        valid = false;
337      }
338      LastSpecialParameterIdx = SpecialParameterIdx;
339
340      // Ensure that all SPK_INT special parameters have the same type.
341      if (SP.kind == SPK_INT) {
342        clang::QualType SpecialParameterType = PVD->getType();
343        if (FirstIntSpecialParameterIdx >= 0) {
344          if (SpecialParameterType != FirstIntSpecialParameterType) {
345            Context->ReportError(PVD->getLocation(),
346                                 "Parameters '%0' and '%1' must be of the same type. "
347                                 "'%0' is of type '%2' while '%1' is of type '%3'.")
348                << specialParameterTable[FirstIntSpecialParameterIdx].name
349                << SP.name
350                << FirstIntSpecialParameterType.getAsString()
351                << SpecialParameterType.getAsString();
352            valid = false;
353          }
354        } else {
355          FirstIntSpecialParameterIdx = SpecialParameterIdx;
356          FirstIntSpecialParameterType = SpecialParameterType;
357        }
358      }
359    } else {
360      // It's not a special parameter.
361      if (*IndexOfFirstSpecialParameter < NumParams) {
362        Context->ReportError(PVD->getLocation(),
363                             "In compute kernel %0(), parameter '%1' cannot "
364                             "appear after any of the (%2) parameters.")
365            << FD->getName() << ParamName << listSpecialParameters(Context->getTargetAPI());
366        valid = false;
367      }
368      continue;
369    }
370    // Validate the data type of the special parameter.
371    switch (specialParameterTable[SpecialParameterIdx].kind) {
372      case SPK_INT: {
373        clang::QualType QT = PVD->getType().getCanonicalType();
374        clang::QualType UT = QT.getUnqualifiedType();
375        if (UT != C.UnsignedIntTy && UT != C.IntTy) {
376          Context->ReportError(PVD->getLocation(),
377                               "Parameter '%0' must be of type 'int' or "
378                               "'unsigned int'. It is of type '%1'.")
379              << ParamName << PVD->getType().getAsString();
380          valid = false;
381        }
382        break;
383      }
384      case SPK_CTXT: {
385        static const char ExpectedTypeNameMatch[] = "const struct rs_kernel_context_t *";
386        static const char ExpectedTypeNamePrint[] = "rs_kernel_context";
387        clang::QualType QT = PVD->getType().getCanonicalType();
388        clang::QualType UT = QT.getUnqualifiedType();
389        if (UT.getAsString() != ExpectedTypeNameMatch) {
390          Context->ReportError(PVD->getLocation(),
391                               "Parameter '%0' must be of type '%1'. "
392                               "It is of type '%2'.")
393              << ParamName << ExpectedTypeNamePrint << PVD->getType().getAsString();
394          valid = false;
395        }
396        break;
397      }
398      default:
399        slangAssert(!"Unexpected special parameter type");
400    }
401    // If this is the first time we find a special parameter, save it.
402    if (*IndexOfFirstSpecialParameter >= NumParams) {
403      *IndexOfFirstSpecialParameter = i;
404    }
405  }
406  return valid;
407}
408
409bool RSExportForEach::setSignatureMetadata(RSContext *Context,
410                                           const clang::FunctionDecl *FD) {
411  mSignatureMetadata = 0;
412  bool valid = true;
413
414  if (mIsKernelStyle) {
415    slangAssert(mOut == nullptr);
416    slangAssert(mUsrData == nullptr);
417  } else {
418    slangAssert(!mHasReturnType);
419  }
420
421  // Set up the bitwise metadata encoding for runtime argument passing.
422  const bool HasOut = mOut || mHasReturnType;
423  mSignatureMetadata |= (hasIns() ?       bcinfo::MD_SIG_In     : 0);
424  mSignatureMetadata |= (HasOut ?         bcinfo::MD_SIG_Out    : 0);
425  mSignatureMetadata |= (mUsrData ?       bcinfo::MD_SIG_Usr    : 0);
426  mSignatureMetadata |= (mIsKernelStyle ? bcinfo::MD_SIG_Kernel : 0);  // pass-by-value
427  mSignatureMetadata |= mSpecialParameterSignatureMetadata;
428
429  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
430    // APIs before ICS cannot skip between parameters. It is ok, however, for
431    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
432    if (mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
433                               bcinfo::MD_SIG_X | bcinfo::MD_SIG_Y) &&
434        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
435                               bcinfo::MD_SIG_X) &&
436        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr) &&
437        mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out) &&
438        mSignatureMetadata != (bcinfo::MD_SIG_In)) {
439      Context->ReportError(FD->getLocation(),
440                           "Compute kernel %0() targeting SDK levels "
441                           "%1-%2 may not skip parameters")
442          << FD->getName() << SLANG_MINIMUM_TARGET_API
443          << (SLANG_ICS_TARGET_API - 1);
444      valid = false;
445    }
446  }
447  return valid;
448}
449
450RSExportForEach *RSExportForEach::Create(RSContext *Context,
451                                         const clang::FunctionDecl *FD) {
452  slangAssert(Context && FD);
453  llvm::StringRef Name = FD->getName();
454  RSExportForEach *FE;
455
456  slangAssert(!Name.empty() && "Function must have a name");
457
458  FE = new RSExportForEach(Context, Name);
459
460  if (!FE->validateAndConstructParams(Context, FD)) {
461    return nullptr;
462  }
463
464  clang::ASTContext &Ctx = Context->getASTContext();
465
466  std::string Id = CreateDummyName("helper_foreach_param", FE->getName());
467
468  // Extract the usrData parameter (if we have one)
469  if (FE->mUsrData) {
470    const clang::ParmVarDecl *PVD = FE->mUsrData;
471    clang::QualType QT = PVD->getType().getCanonicalType();
472    slangAssert(QT->isPointerType() &&
473                QT->getPointeeType().isConstQualified());
474
475    const clang::ASTContext &C = Context->getASTContext();
476    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
477        C.VoidTy) {
478      // In the case of using const void*, we can't reflect an appopriate
479      // Java type, so we fall back to just reflecting the ain/aout parameters
480      FE->mUsrData = nullptr;
481    } else {
482      clang::RecordDecl *RD =
483          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
484                                    Ctx.getTranslationUnitDecl(),
485                                    clang::SourceLocation(),
486                                    clang::SourceLocation(),
487                                    &Ctx.Idents.get(Id));
488
489      clang::FieldDecl *FD =
490          clang::FieldDecl::Create(Ctx,
491                                   RD,
492                                   clang::SourceLocation(),
493                                   clang::SourceLocation(),
494                                   PVD->getIdentifier(),
495                                   QT->getPointeeType(),
496                                   nullptr,
497                                   /* BitWidth = */ nullptr,
498                                   /* Mutable = */ false,
499                                   /* HasInit = */ clang::ICIS_NoInit);
500      RD->addDecl(FD);
501      RD->completeDefinition();
502
503      // Create an export type iff we have a valid usrData type
504      clang::QualType T = Ctx.getTagDeclType(RD);
505      slangAssert(!T.isNull());
506
507      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
508
509      if (ET == nullptr) {
510        fprintf(stderr, "Failed to export the function %s. There's at least "
511                        "one parameter whose type is not supported by the "
512                        "reflection\n", FE->getName().c_str());
513        return nullptr;
514      }
515
516      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
517                  "Parameter packet must be a record");
518
519      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
520    }
521  }
522
523  if (FE->hasIns()) {
524
525    for (InIter BI = FE->mIns.begin(), EI = FE->mIns.end(); BI != EI; BI++) {
526      const clang::Type *T = (*BI)->getType().getCanonicalType().getTypePtr();
527      RSExportType *InExportType = RSExportType::Create(Context, T);
528
529      if (FE->mIsKernelStyle) {
530        slangAssert(InExportType != nullptr);
531      }
532
533      FE->mInTypes.push_back(InExportType);
534    }
535  }
536
537  if (FE->mIsKernelStyle && FE->mHasReturnType) {
538    const clang::Type *T = FE->mResultType.getTypePtr();
539    FE->mOutType = RSExportType::Create(Context, T);
540    slangAssert(FE->mOutType);
541  } else if (FE->mOut) {
542    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
543    FE->mOutType = RSExportType::Create(Context, T);
544  }
545
546  return FE;
547}
548
549RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
550  slangAssert(Context);
551  llvm::StringRef Name = "root";
552  RSExportForEach *FE = new RSExportForEach(Context, Name);
553  FE->mDummyRoot = true;
554  return FE;
555}
556
557bool RSExportForEach::isGraphicsRootRSFunc(unsigned int targetAPI,
558                                           const clang::FunctionDecl *FD) {
559  if (FD->hasAttr<clang::KernelAttr>()) {
560    return false;
561  }
562
563  if (!isRootRSFunc(FD)) {
564    return false;
565  }
566
567  if (FD->getNumParams() == 0) {
568    // Graphics root function
569    return true;
570  }
571
572  // Check for legacy graphics root function (with single parameter).
573  if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
574    const clang::QualType &IntType = FD->getASTContext().IntTy;
575    if (FD->getReturnType().getCanonicalType() == IntType) {
576      return true;
577    }
578  }
579
580  return false;
581}
582
583bool RSExportForEach::isRSForEachFunc(unsigned int targetAPI,
584                                      slang::RSContext* Context,
585                                      const clang::FunctionDecl *FD) {
586  slangAssert(Context && FD);
587  bool hasKernelAttr = FD->hasAttr<clang::KernelAttr>();
588
589  if (FD->getStorageClass() == clang::SC_Static) {
590    if (hasKernelAttr) {
591      Context->ReportError(FD->getLocation(),
592                           "Invalid use of attribute kernel with "
593                           "static function declaration: %0")
594          << FD->getName();
595    }
596    return false;
597  }
598
599  // Anything tagged as a kernel is definitely used with ForEach.
600  if (hasKernelAttr) {
601    return true;
602  }
603
604  if (isGraphicsRootRSFunc(targetAPI, FD)) {
605    return false;
606  }
607
608  // Check if first parameter is a pointer (which is required for ForEach).
609  unsigned int numParams = FD->getNumParams();
610
611  if (numParams > 0) {
612    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
613    clang::QualType QT = PVD->getType().getCanonicalType();
614
615    if (QT->isPointerType()) {
616      return true;
617    }
618
619    // Any non-graphics root() is automatically a ForEach candidate.
620    // At this point, however, we know that it is not going to be a valid
621    // compute root() function (due to not having a pointer parameter). We
622    // still want to return true here, so that we can issue appropriate
623    // diagnostics.
624    if (isRootRSFunc(FD)) {
625      return true;
626    }
627  }
628
629  return false;
630}
631
632bool
633RSExportForEach::validateSpecialFuncDecl(unsigned int targetAPI,
634                                         slang::RSContext *Context,
635                                         clang::FunctionDecl const *FD) {
636  slangAssert(Context && FD);
637  bool valid = true;
638  const clang::ASTContext &C = FD->getASTContext();
639  const clang::QualType &IntType = FD->getASTContext().IntTy;
640
641  if (isGraphicsRootRSFunc(targetAPI, FD)) {
642    if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
643      // Legacy graphics root function
644      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
645      clang::QualType QT = PVD->getType().getCanonicalType();
646      if (QT != IntType) {
647        Context->ReportError(PVD->getLocation(),
648                             "invalid parameter type for legacy "
649                             "graphics root() function: %0")
650            << PVD->getType();
651        valid = false;
652      }
653    }
654
655    // Graphics root function, so verify that it returns an int
656    if (FD->getReturnType().getCanonicalType() != IntType) {
657      Context->ReportError(FD->getLocation(),
658                           "root() is required to return "
659                           "an int for graphics usage");
660      valid = false;
661    }
662  } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
663    if (FD->getNumParams() != 0) {
664      Context->ReportError(FD->getLocation(),
665                           "%0(void) is required to have no "
666                           "parameters")
667          << FD->getName();
668      valid = false;
669    }
670
671    if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
672      Context->ReportError(FD->getLocation(),
673                           "%0(void) is required to have a void "
674                           "return type")
675          << FD->getName();
676      valid = false;
677    }
678  } else {
679    slangAssert(false && "must be called on root, init or .rs.dtor function!");
680  }
681
682  return valid;
683}
684
685}  // namespace slang
686