1/*
2 * Copyright 2011-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Attr.h"
23#include "clang/AST/Decl.h"
24#include "clang/AST/TypeLoc.h"
25
26#include "llvm/IR/DerivedTypes.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31#include "slang_version.h"
32
33namespace slang {
34
35// This function takes care of additional validation and construction of
36// parameters related to forEach_* reflection.
37bool RSExportForEach::validateAndConstructParams(
38    RSContext *Context, const clang::FunctionDecl *FD) {
39  slangAssert(Context && FD);
40  bool valid = true;
41
42  numParams = FD->getNumParams();
43
44  if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
45    // Before JellyBean, we allowed only one kernel per file.  It must be called "root".
46    if (!isRootRSFunc(FD)) {
47      Context->ReportError(FD->getLocation(),
48                           "Non-root compute kernel %0() is "
49                           "not supported in SDK levels %1-%2")
50          << FD->getName() << SLANG_MINIMUM_TARGET_API
51          << (SLANG_JB_TARGET_API - 1);
52      return false;
53    }
54  }
55
56  mResultType = FD->getReturnType().getCanonicalType();
57  // Compute kernel functions are defined differently when the
58  // "__attribute__((kernel))" is set.
59  if (FD->hasAttr<clang::KernelAttr>()) {
60    valid |= validateAndConstructKernelParams(Context, FD);
61  } else {
62    valid |= validateAndConstructOldStyleParams(Context, FD);
63  }
64
65  valid |= setSignatureMetadata(Context, FD);
66  return valid;
67}
68
69bool RSExportForEach::validateAndConstructOldStyleParams(
70    RSContext *Context, const clang::FunctionDecl *FD) {
71  slangAssert(Context && FD);
72  // If numParams is 0, we already marked this as a graphics root().
73  slangAssert(numParams > 0);
74
75  bool valid = true;
76
77  // Compute kernel functions of this style are required to return a void type.
78  clang::ASTContext &C = Context->getASTContext();
79  if (mResultType != C.VoidTy) {
80    Context->ReportError(FD->getLocation(),
81                         "Compute kernel %0() is required to return a "
82                         "void type")
83        << FD->getName();
84    valid = false;
85  }
86
87  // Validate remaining parameter types
88  // TODO(all): Add support for LOD/face when we have them
89
90  size_t IndexOfFirstIterator = numParams;
91  valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator);
92
93  // Validate the non-iterator parameters, which should all be found before the
94  // first iterator.
95  for (size_t i = 0; i < IndexOfFirstIterator; i++) {
96    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
97    clang::QualType QT = PVD->getType().getCanonicalType();
98
99    if (!QT->isPointerType()) {
100      Context->ReportError(PVD->getLocation(),
101                           "Compute kernel %0() cannot have non-pointer "
102                           "parameters besides 'x' and 'y'. Parameter '%1' is "
103                           "of type: '%2'")
104          << FD->getName() << PVD->getName() << PVD->getType().getAsString();
105      valid = false;
106      continue;
107    }
108
109    // The only non-const pointer should be out.
110    if (!QT->getPointeeType().isConstQualified()) {
111      if (mOut == NULL) {
112        mOut = PVD;
113      } else {
114        Context->ReportError(PVD->getLocation(),
115                             "Compute kernel %0() can only have one non-const "
116                             "pointer parameter. Parameters '%1' and '%2' are "
117                             "both non-const.")
118            << FD->getName() << mOut->getName() << PVD->getName();
119        valid = false;
120      }
121    } else {
122      if (mIns.empty() && mOut == NULL) {
123        mIns.push_back(PVD);
124      } else if (mUsrData == NULL) {
125        mUsrData = PVD;
126      } else {
127        Context->ReportError(
128            PVD->getLocation(),
129            "Unexpected parameter '%0' for compute kernel %1()")
130            << PVD->getName() << FD->getName();
131        valid = false;
132      }
133    }
134  }
135
136  if (mIns.empty() && !mOut) {
137    Context->ReportError(FD->getLocation(),
138                         "Compute kernel %0() must have at least one "
139                         "parameter for in or out")
140        << FD->getName();
141    valid = false;
142  }
143
144  return valid;
145}
146
147bool RSExportForEach::validateAndConstructKernelParams(
148    RSContext *Context, const clang::FunctionDecl *FD) {
149  slangAssert(Context && FD);
150  bool valid = true;
151  clang::ASTContext &C = Context->getASTContext();
152
153  if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
154    Context->ReportError(FD->getLocation(),
155                         "Compute kernel %0() targeting SDK levels "
156                         "%1-%2 may not use pass-by-value with "
157                         "__attribute__((kernel))")
158        << FD->getName() << SLANG_MINIMUM_TARGET_API
159        << (SLANG_JB_MR1_TARGET_API - 1);
160    return false;
161  }
162
163  // Denote that we are indeed a pass-by-value kernel.
164  mIsKernelStyle = true;
165  mHasReturnType = (mResultType != C.VoidTy);
166
167  if (mResultType->isPointerType()) {
168    Context->ReportError(
169        FD->getTypeSpecStartLoc(),
170        "Compute kernel %0() cannot return a pointer type: '%1'")
171        << FD->getName() << mResultType.getAsString();
172    valid = false;
173  }
174
175  // Validate remaining parameter types
176  // TODO(all): Add support for LOD/face when we have them
177
178  size_t IndexOfFirstIterator = numParams;
179  valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator);
180
181  // Validate the non-iterator parameters, which should all be found before the
182  // first iterator.
183  for (size_t i = 0; i < IndexOfFirstIterator; i++) {
184    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
185
186    /*
187     * FIXME: Change this to a test against an actual API version when the
188     *        multi-input feature is officially supported.
189     */
190    if (Context->getTargetAPI() == SLANG_DEVELOPMENT_TARGET_API || i == 0) {
191      mIns.push_back(PVD);
192    } else {
193      Context->ReportError(PVD->getLocation(),
194                           "Invalid parameter '%0' for compute kernel %1(). "
195                           "Kernels targeting SDK levels %2-%3 may not use "
196                           "multiple input parameters.") << PVD->getName() <<
197                           FD->getName() << SLANG_MINIMUM_TARGET_API <<
198                           SLANG_MAXIMUM_TARGET_API;
199      valid = false;
200    }
201    clang::QualType QT = PVD->getType().getCanonicalType();
202    if (QT->isPointerType()) {
203      Context->ReportError(PVD->getLocation(),
204                           "Compute kernel %0() cannot have "
205                           "parameter '%1' of pointer type: '%2'")
206          << FD->getName() << PVD->getName() << PVD->getType().getAsString();
207      valid = false;
208    }
209  }
210
211  // Check that we have at least one allocation to use for dimensions.
212  if (valid && mIns.empty() && !mHasReturnType) {
213    Context->ReportError(FD->getLocation(),
214                         "Compute kernel %0() must have at least one "
215                         "input parameter or a non-void return "
216                         "type")
217        << FD->getName();
218    valid = false;
219  }
220
221  return valid;
222}
223
224// Search for the optional x and y parameters.  Returns true if valid.   Also
225// sets *IndexOfFirstIterator to the index of the first iterator parameter, or
226// FD->getNumParams() if none are found.
227bool RSExportForEach::validateIterationParameters(
228    RSContext *Context, const clang::FunctionDecl *FD,
229    size_t *IndexOfFirstIterator) {
230  slangAssert(IndexOfFirstIterator != NULL);
231  slangAssert(mX == NULL && mY == NULL);
232  clang::ASTContext &C = Context->getASTContext();
233
234  // Find the x and y parameters if present.
235  size_t NumParams = FD->getNumParams();
236  *IndexOfFirstIterator = NumParams;
237  bool valid = true;
238  for (size_t i = 0; i < NumParams; i++) {
239    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
240    llvm::StringRef ParamName = PVD->getName();
241    if (ParamName.equals("x")) {
242      slangAssert(mX == NULL);  // We won't be invoked if two 'x' are present.
243      mX = PVD;
244      if (mY != NULL) {
245        Context->ReportError(PVD->getLocation(),
246                             "In compute kernel %0(), parameter 'x' should "
247                             "be defined before parameter 'y'")
248            << FD->getName();
249        valid = false;
250      }
251    } else if (ParamName.equals("y")) {
252      slangAssert(mY == NULL);  // We won't be invoked if two 'y' are present.
253      mY = PVD;
254    } else {
255      // It's neither x nor y.
256      if (*IndexOfFirstIterator < NumParams) {
257        Context->ReportError(PVD->getLocation(),
258                             "In compute kernel %0(), parameter '%1' cannot "
259                             "appear after the 'x' and 'y' parameters")
260            << FD->getName() << ParamName;
261        valid = false;
262      }
263      continue;
264    }
265    // Validate the data type of x and y.
266    clang::QualType QT = PVD->getType().getCanonicalType();
267    clang::QualType UT = QT.getUnqualifiedType();
268    if (UT != C.UnsignedIntTy && UT != C.IntTy) {
269      Context->ReportError(PVD->getLocation(),
270                           "Parameter '%0' must be of type 'int' or "
271                           "'unsigned int'. It is of type '%1'")
272          << ParamName << PVD->getType().getAsString();
273      valid = false;
274    }
275    // If this is the first time we find an iterator, save it.
276    if (*IndexOfFirstIterator >= NumParams) {
277      *IndexOfFirstIterator = i;
278    }
279  }
280  // Check that x and y have the same type.
281  if (mX != NULL and mY != NULL) {
282    clang::QualType XType = mX->getType();
283    clang::QualType YType = mY->getType();
284
285    if (XType != YType) {
286      Context->ReportError(mY->getLocation(),
287                           "Parameter 'x' and 'y' must be of the same type. "
288                           "'x' is of type '%0' while 'y' is of type '%1'")
289          << XType.getAsString() << YType.getAsString();
290      valid = false;
291    }
292  }
293  return valid;
294}
295
296bool RSExportForEach::setSignatureMetadata(RSContext *Context,
297                                           const clang::FunctionDecl *FD) {
298  mSignatureMetadata = 0;
299  bool valid = true;
300
301  if (mIsKernelStyle) {
302    slangAssert(mOut == NULL);
303    slangAssert(mUsrData == NULL);
304  } else {
305    slangAssert(!mHasReturnType);
306  }
307
308  // Set up the bitwise metadata encoding for runtime argument passing.
309  // TODO: If this bit field is re-used from C++ code, define the values in a header.
310  const bool HasOut = mOut || mHasReturnType;
311  mSignatureMetadata |= (hasIns() ?       0x01 : 0);
312  mSignatureMetadata |= (HasOut ?         0x02 : 0);
313  mSignatureMetadata |= (mUsrData ?       0x04 : 0);
314  mSignatureMetadata |= (mX ?             0x08 : 0);
315  mSignatureMetadata |= (mY ?             0x10 : 0);
316  mSignatureMetadata |= (mIsKernelStyle ? 0x20 : 0);  // pass-by-value
317
318  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
319    // APIs before ICS cannot skip between parameters. It is ok, however, for
320    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
321    if (mSignatureMetadata != 0x1f &&  // In, Out, UsrData, X, Y
322        mSignatureMetadata != 0x0f &&  // In, Out, UsrData, X
323        mSignatureMetadata != 0x07 &&  // In, Out, UsrData
324        mSignatureMetadata != 0x03 &&  // In, Out
325        mSignatureMetadata != 0x01) {  // In
326      Context->ReportError(FD->getLocation(),
327                           "Compute kernel %0() targeting SDK levels "
328                           "%1-%2 may not skip parameters")
329          << FD->getName() << SLANG_MINIMUM_TARGET_API
330          << (SLANG_ICS_TARGET_API - 1);
331      valid = false;
332    }
333  }
334  return valid;
335}
336
337RSExportForEach *RSExportForEach::Create(RSContext *Context,
338                                         const clang::FunctionDecl *FD) {
339  slangAssert(Context && FD);
340  llvm::StringRef Name = FD->getName();
341  RSExportForEach *FE;
342
343  slangAssert(!Name.empty() && "Function must have a name");
344
345  FE = new RSExportForEach(Context, Name);
346
347  if (!FE->validateAndConstructParams(Context, FD)) {
348    return NULL;
349  }
350
351  clang::ASTContext &Ctx = Context->getASTContext();
352
353  std::string Id = CreateDummyName("helper_foreach_param", FE->getName());
354
355  // Extract the usrData parameter (if we have one)
356  if (FE->mUsrData) {
357    const clang::ParmVarDecl *PVD = FE->mUsrData;
358    clang::QualType QT = PVD->getType().getCanonicalType();
359    slangAssert(QT->isPointerType() &&
360                QT->getPointeeType().isConstQualified());
361
362    const clang::ASTContext &C = Context->getASTContext();
363    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
364        C.VoidTy) {
365      // In the case of using const void*, we can't reflect an appopriate
366      // Java type, so we fall back to just reflecting the ain/aout parameters
367      FE->mUsrData = NULL;
368    } else {
369      clang::RecordDecl *RD =
370          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
371                                    Ctx.getTranslationUnitDecl(),
372                                    clang::SourceLocation(),
373                                    clang::SourceLocation(),
374                                    &Ctx.Idents.get(Id));
375
376      clang::FieldDecl *FD =
377          clang::FieldDecl::Create(Ctx,
378                                   RD,
379                                   clang::SourceLocation(),
380                                   clang::SourceLocation(),
381                                   PVD->getIdentifier(),
382                                   QT->getPointeeType(),
383                                   NULL,
384                                   /* BitWidth = */ NULL,
385                                   /* Mutable = */ false,
386                                   /* HasInit = */ clang::ICIS_NoInit);
387      RD->addDecl(FD);
388      RD->completeDefinition();
389
390      // Create an export type iff we have a valid usrData type
391      clang::QualType T = Ctx.getTagDeclType(RD);
392      slangAssert(!T.isNull());
393
394      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
395
396      if (ET == NULL) {
397        fprintf(stderr, "Failed to export the function %s. There's at least "
398                        "one parameter whose type is not supported by the "
399                        "reflection\n", FE->getName().c_str());
400        return NULL;
401      }
402
403      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
404                  "Parameter packet must be a record");
405
406      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
407    }
408  }
409
410  if (FE->hasIns()) {
411
412    for (InIter BI = FE->mIns.begin(), EI = FE->mIns.end(); BI != EI; BI++) {
413      const clang::Type *T = (*BI)->getType().getCanonicalType().getTypePtr();
414      RSExportType *InExportType = RSExportType::Create(Context, T);
415
416      if (FE->mIsKernelStyle) {
417        slangAssert(InExportType != NULL);
418      }
419
420      FE->mInTypes.push_back(InExportType);
421    }
422  }
423
424  if (FE->mIsKernelStyle && FE->mHasReturnType) {
425    const clang::Type *T = FE->mResultType.getTypePtr();
426    FE->mOutType = RSExportType::Create(Context, T);
427    slangAssert(FE->mOutType);
428  } else if (FE->mOut) {
429    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
430    FE->mOutType = RSExportType::Create(Context, T);
431  }
432
433  return FE;
434}
435
436RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
437  slangAssert(Context);
438  llvm::StringRef Name = "root";
439  RSExportForEach *FE = new RSExportForEach(Context, Name);
440  FE->mDummyRoot = true;
441  return FE;
442}
443
444bool RSExportForEach::isGraphicsRootRSFunc(unsigned int targetAPI,
445                                           const clang::FunctionDecl *FD) {
446  if (FD->hasAttr<clang::KernelAttr>()) {
447    return false;
448  }
449
450  if (!isRootRSFunc(FD)) {
451    return false;
452  }
453
454  if (FD->getNumParams() == 0) {
455    // Graphics root function
456    return true;
457  }
458
459  // Check for legacy graphics root function (with single parameter).
460  if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
461    const clang::QualType &IntType = FD->getASTContext().IntTy;
462    if (FD->getReturnType().getCanonicalType() == IntType) {
463      return true;
464    }
465  }
466
467  return false;
468}
469
470bool RSExportForEach::isRSForEachFunc(unsigned int targetAPI,
471                                      slang::RSContext* Context,
472                                      const clang::FunctionDecl *FD) {
473  slangAssert(Context && FD);
474  bool hasKernelAttr = FD->hasAttr<clang::KernelAttr>();
475
476  if (FD->getStorageClass() == clang::SC_Static) {
477    if (hasKernelAttr) {
478      Context->ReportError(FD->getLocation(),
479                           "Invalid use of attribute kernel with "
480                           "static function declaration: %0")
481          << FD->getName();
482    }
483    return false;
484  }
485
486  // Anything tagged as a kernel is definitely used with ForEach.
487  if (hasKernelAttr) {
488    return true;
489  }
490
491  if (isGraphicsRootRSFunc(targetAPI, FD)) {
492    return false;
493  }
494
495  // Check if first parameter is a pointer (which is required for ForEach).
496  unsigned int numParams = FD->getNumParams();
497
498  if (numParams > 0) {
499    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
500    clang::QualType QT = PVD->getType().getCanonicalType();
501
502    if (QT->isPointerType()) {
503      return true;
504    }
505
506    // Any non-graphics root() is automatically a ForEach candidate.
507    // At this point, however, we know that it is not going to be a valid
508    // compute root() function (due to not having a pointer parameter). We
509    // still want to return true here, so that we can issue appropriate
510    // diagnostics.
511    if (isRootRSFunc(FD)) {
512      return true;
513    }
514  }
515
516  return false;
517}
518
519bool
520RSExportForEach::validateSpecialFuncDecl(unsigned int targetAPI,
521                                         slang::RSContext *Context,
522                                         clang::FunctionDecl const *FD) {
523  slangAssert(Context && FD);
524  bool valid = true;
525  const clang::ASTContext &C = FD->getASTContext();
526  const clang::QualType &IntType = FD->getASTContext().IntTy;
527
528  if (isGraphicsRootRSFunc(targetAPI, FD)) {
529    if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
530      // Legacy graphics root function
531      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
532      clang::QualType QT = PVD->getType().getCanonicalType();
533      if (QT != IntType) {
534        Context->ReportError(PVD->getLocation(),
535                             "invalid parameter type for legacy "
536                             "graphics root() function: %0")
537            << PVD->getType();
538        valid = false;
539      }
540    }
541
542    // Graphics root function, so verify that it returns an int
543    if (FD->getReturnType().getCanonicalType() != IntType) {
544      Context->ReportError(FD->getLocation(),
545                           "root() is required to return "
546                           "an int for graphics usage");
547      valid = false;
548    }
549  } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
550    if (FD->getNumParams() != 0) {
551      Context->ReportError(FD->getLocation(),
552                           "%0(void) is required to have no "
553                           "parameters")
554          << FD->getName();
555      valid = false;
556    }
557
558    if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
559      Context->ReportError(FD->getLocation(),
560                           "%0(void) is required to have a void "
561                           "return type")
562          << FD->getName();
563      valid = false;
564    }
565  } else {
566    slangAssert(false && "must be called on root, init or .rs.dtor function!");
567  }
568
569  return valid;
570}
571
572}  // namespace slang
573