slang_rs_export_foreach.cpp revision ee4016d1247d3fbe50822de279d3da273d8aef4c
1/*
2 * Copyright 2011-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Attr.h"
23#include "clang/AST/Decl.h"
24#include "clang/AST/TypeLoc.h"
25
26#include "llvm/IR/DerivedTypes.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31#include "slang_version.h"
32
33namespace slang {
34
35// This function takes care of additional validation and construction of
36// parameters related to forEach_* reflection.
37bool RSExportForEach::validateAndConstructParams(
38    RSContext *Context, const clang::FunctionDecl *FD) {
39  slangAssert(Context && FD);
40  bool valid = true;
41
42  numParams = FD->getNumParams();
43
44  if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
45    // Before JellyBean, we allowed only one kernel per file.  It must be called "root".
46    if (!isRootRSFunc(FD)) {
47      Context->ReportError(FD->getLocation(),
48                           "Non-root compute kernel %0() is "
49                           "not supported in SDK levels %1-%2")
50          << FD->getName() << SLANG_MINIMUM_TARGET_API
51          << (SLANG_JB_TARGET_API - 1);
52      return false;
53    }
54  }
55
56  mResultType = FD->getReturnType().getCanonicalType();
57  // Compute kernel functions are defined differently when the
58  // "__attribute__((kernel))" is set.
59  if (FD->hasAttr<clang::KernelAttr>()) {
60    valid |= validateAndConstructKernelParams(Context, FD);
61  } else {
62    valid |= validateAndConstructOldStyleParams(Context, FD);
63  }
64  valid |= setSignatureMetadata(Context, FD);
65  return valid;
66}
67
68bool RSExportForEach::validateAndConstructOldStyleParams(
69    RSContext *Context, const clang::FunctionDecl *FD) {
70  slangAssert(Context && FD);
71  // If numParams is 0, we already marked this as a graphics root().
72  slangAssert(numParams > 0);
73
74  bool valid = true;
75
76  // Compute kernel functions of this style are required to return a void type.
77  clang::ASTContext &C = Context->getASTContext();
78  if (mResultType != C.VoidTy) {
79    Context->ReportError(FD->getLocation(),
80                         "Compute kernel %0() is required to return a "
81                         "void type")
82        << FD->getName();
83    valid = false;
84  }
85
86  // Validate remaining parameter types
87  // TODO(all): Add support for LOD/face when we have them
88
89  size_t IndexOfFirstIterator = numParams;
90  valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator);
91
92  // Validate the non-iterator parameters, which should all be found before the
93  // first iterator.
94  for (size_t i = 0; i < IndexOfFirstIterator; i++) {
95    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
96    clang::QualType QT = PVD->getType().getCanonicalType();
97
98    if (!QT->isPointerType()) {
99      Context->ReportError(PVD->getLocation(),
100                           "Compute kernel %0() cannot have non-pointer "
101                           "parameters besides 'x' and 'y'. Parameter '%1' is "
102                           "of type: '%2'")
103          << FD->getName() << PVD->getName() << PVD->getType().getAsString();
104      valid = false;
105      continue;
106    }
107
108    // The only non-const pointer should be out.
109    if (!QT->getPointeeType().isConstQualified()) {
110      if (mOut == NULL) {
111        mOut = PVD;
112      } else {
113        Context->ReportError(PVD->getLocation(),
114                             "Compute kernel %0() can only have one non-const "
115                             "pointer parameter. Parameters '%1' and '%2' are "
116                             "both non-const.")
117            << FD->getName() << mOut->getName() << PVD->getName();
118        valid = false;
119      }
120    } else {
121      if (mIn == NULL && mOut == NULL) {
122        mIn = PVD;
123      } else if (mUsrData == NULL) {
124        mUsrData = PVD;
125      } else {
126        Context->ReportError(
127            PVD->getLocation(),
128            "Unexpected parameter '%0' for compute kernel %1()")
129            << PVD->getName() << FD->getName();
130        valid = false;
131      }
132    }
133  }
134
135  if (!mIn && !mOut) {
136    Context->ReportError(FD->getLocation(),
137                         "Compute kernel %0() must have at least one "
138                         "parameter for in or out")
139        << FD->getName();
140    valid = false;
141  }
142
143  return valid;
144}
145
146bool RSExportForEach::validateAndConstructKernelParams(
147    RSContext *Context, const clang::FunctionDecl *FD) {
148  slangAssert(Context && FD);
149  bool valid = true;
150  clang::ASTContext &C = Context->getASTContext();
151
152  if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
153    Context->ReportError(FD->getLocation(),
154                         "Compute kernel %0() targeting SDK levels "
155                         "%1-%2 may not use pass-by-value with "
156                         "__attribute__((kernel))")
157        << FD->getName() << SLANG_MINIMUM_TARGET_API
158        << (SLANG_JB_MR1_TARGET_API - 1);
159    return false;
160  }
161
162  // Denote that we are indeed a pass-by-value kernel.
163  mIsKernelStyle = true;
164  mHasReturnType = (mResultType != C.VoidTy);
165
166  if (mResultType->isPointerType()) {
167    Context->ReportError(
168        FD->getTypeSpecStartLoc(),
169        "Compute kernel %0() cannot return a pointer type: '%1'")
170        << FD->getName() << mResultType.getAsString();
171    valid = false;
172  }
173
174  // Validate remaining parameter types
175  // TODO(all): Add support for LOD/face when we have them
176
177  size_t IndexOfFirstIterator = numParams;
178  valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator);
179
180  // Validate the non-iterator parameters, which should all be found before the
181  // first iterator.
182  for (size_t i = 0; i < IndexOfFirstIterator; i++) {
183    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
184    if (i == 0) {
185      mIn = PVD;
186    } else {
187      Context->ReportError(PVD->getLocation(),
188                           "Unrecognized parameter '%0'. Compute kernel %1() "
189                           "can only have one input parameter, 'x', and 'y'")
190          << PVD->getName() << FD->getName();
191      valid = false;
192    }
193    clang::QualType QT = PVD->getType().getCanonicalType();
194    if (QT->isPointerType()) {
195      Context->ReportError(PVD->getLocation(),
196                           "Compute kernel %0() cannot have "
197                           "parameter '%1' of pointer type: '%2'")
198          << FD->getName() << PVD->getName() << PVD->getType().getAsString();
199      valid = false;
200    }
201  }
202
203  // Check that we have at least one allocation to use for dimensions.
204  if (valid && !mIn && !mHasReturnType) {
205    Context->ReportError(FD->getLocation(),
206                         "Compute kernel %0() must have at least one "
207                         "input parameter or a non-void return "
208                         "type")
209        << FD->getName();
210    valid = false;
211  }
212
213  return valid;
214}
215
216// Search for the optional x and y parameters.  Returns true if valid.   Also
217// sets *IndexOfFirstIterator to the index of the first iterator parameter, or
218// FD->getNumParams() if none are found.
219bool RSExportForEach::validateIterationParameters(
220    RSContext *Context, const clang::FunctionDecl *FD,
221    size_t *IndexOfFirstIterator) {
222  slangAssert(IndexOfFirstIterator != NULL);
223  slangAssert(mX == NULL && mY == NULL);
224  clang::ASTContext &C = Context->getASTContext();
225
226  // Find the x and y parameters if present.
227  size_t NumParams = FD->getNumParams();
228  *IndexOfFirstIterator = NumParams;
229  bool valid = true;
230  for (size_t i = 0; i < NumParams; i++) {
231    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
232    llvm::StringRef ParamName = PVD->getName();
233    if (ParamName.equals("x")) {
234      slangAssert(mX == NULL);  // We won't be invoked if two 'x' are present.
235      mX = PVD;
236      if (mY != NULL) {
237        Context->ReportError(PVD->getLocation(),
238                             "In compute kernel %0(), parameter 'x' should "
239                             "be defined before parameter 'y'")
240            << FD->getName();
241        valid = false;
242      }
243    } else if (ParamName.equals("y")) {
244      slangAssert(mY == NULL);  // We won't be invoked if two 'y' are present.
245      mY = PVD;
246    } else {
247      // It's neither x nor y.
248      if (*IndexOfFirstIterator < NumParams) {
249        Context->ReportError(PVD->getLocation(),
250                             "In compute kernel %0(), parameter '%1' cannot "
251                             "appear after the 'x' and 'y' parameters")
252            << FD->getName() << ParamName;
253        valid = false;
254      }
255      continue;
256    }
257    // Validate the data type of x and y.
258    clang::QualType QT = PVD->getType().getCanonicalType();
259    clang::QualType UT = QT.getUnqualifiedType();
260    if (UT != C.UnsignedIntTy && UT != C.IntTy) {
261      Context->ReportError(PVD->getLocation(),
262                           "Parameter '%0' must be of type 'int' or "
263                           "'unsigned int'. It is of type '%1'")
264          << ParamName << PVD->getType().getAsString();
265      valid = false;
266    }
267    // If this is the first time we find an iterator, save it.
268    if (*IndexOfFirstIterator >= NumParams) {
269      *IndexOfFirstIterator = i;
270    }
271  }
272  // Check that x and y have the same type.
273  if (mX != NULL and mY != NULL) {
274    clang::QualType XType = mX->getType();
275    clang::QualType YType = mY->getType();
276
277    if (XType != YType) {
278      Context->ReportError(mY->getLocation(),
279                           "Parameter 'x' and 'y' must be of the same type. "
280                           "'x' is of type '%0' while 'y' is of type '%1'")
281          << XType.getAsString() << YType.getAsString();
282      valid = false;
283    }
284  }
285  return valid;
286}
287
288bool RSExportForEach::setSignatureMetadata(RSContext *Context,
289                                           const clang::FunctionDecl *FD) {
290  mSignatureMetadata = 0;
291  bool valid = true;
292
293  if (mIsKernelStyle) {
294    slangAssert(mOut == NULL);
295    slangAssert(mUsrData == NULL);
296  } else {
297    slangAssert(!mHasReturnType);
298  }
299
300  // Set up the bitwise metadata encoding for runtime argument passing.
301  // TODO: If this bit field is re-used from C++ code, define the values in a header.
302  const bool HasOut = mOut || mHasReturnType;
303  mSignatureMetadata |= (mIn ?            0x01 : 0);
304  mSignatureMetadata |= (HasOut ?         0x02 : 0);
305  mSignatureMetadata |= (mUsrData ?       0x04 : 0);
306  mSignatureMetadata |= (mX ?             0x08 : 0);
307  mSignatureMetadata |= (mY ?             0x10 : 0);
308  mSignatureMetadata |= (mIsKernelStyle ? 0x20 : 0);  // pass-by-value
309
310  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
311    // APIs before ICS cannot skip between parameters. It is ok, however, for
312    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
313    if (mSignatureMetadata != 0x1f &&  // In, Out, UsrData, X, Y
314        mSignatureMetadata != 0x0f &&  // In, Out, UsrData, X
315        mSignatureMetadata != 0x07 &&  // In, Out, UsrData
316        mSignatureMetadata != 0x03 &&  // In, Out
317        mSignatureMetadata != 0x01) {  // In
318      Context->ReportError(FD->getLocation(),
319                           "Compute kernel %0() targeting SDK levels "
320                           "%1-%2 may not skip parameters")
321          << FD->getName() << SLANG_MINIMUM_TARGET_API
322          << (SLANG_ICS_TARGET_API - 1);
323      valid = false;
324    }
325  }
326  return valid;
327}
328
329RSExportForEach *RSExportForEach::Create(RSContext *Context,
330                                         const clang::FunctionDecl *FD) {
331  slangAssert(Context && FD);
332  llvm::StringRef Name = FD->getName();
333  RSExportForEach *FE;
334
335  slangAssert(!Name.empty() && "Function must have a name");
336
337  FE = new RSExportForEach(Context, Name);
338
339  if (!FE->validateAndConstructParams(Context, FD)) {
340    return NULL;
341  }
342
343  clang::ASTContext &Ctx = Context->getASTContext();
344
345  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
346  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
347
348  // Extract the usrData parameter (if we have one)
349  if (FE->mUsrData) {
350    const clang::ParmVarDecl *PVD = FE->mUsrData;
351    clang::QualType QT = PVD->getType().getCanonicalType();
352    slangAssert(QT->isPointerType() &&
353                QT->getPointeeType().isConstQualified());
354
355    const clang::ASTContext &C = Context->getASTContext();
356    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
357        C.VoidTy) {
358      // In the case of using const void*, we can't reflect an appopriate
359      // Java type, so we fall back to just reflecting the ain/aout parameters
360      FE->mUsrData = NULL;
361    } else {
362      clang::RecordDecl *RD =
363          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
364                                    Ctx.getTranslationUnitDecl(),
365                                    clang::SourceLocation(),
366                                    clang::SourceLocation(),
367                                    &Ctx.Idents.get(Id));
368
369      clang::FieldDecl *FD =
370          clang::FieldDecl::Create(Ctx,
371                                   RD,
372                                   clang::SourceLocation(),
373                                   clang::SourceLocation(),
374                                   PVD->getIdentifier(),
375                                   QT->getPointeeType(),
376                                   NULL,
377                                   /* BitWidth = */ NULL,
378                                   /* Mutable = */ false,
379                                   /* HasInit = */ clang::ICIS_NoInit);
380      RD->addDecl(FD);
381      RD->completeDefinition();
382
383      // Create an export type iff we have a valid usrData type
384      clang::QualType T = Ctx.getTagDeclType(RD);
385      slangAssert(!T.isNull());
386
387      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
388
389      if (ET == NULL) {
390        fprintf(stderr, "Failed to export the function %s. There's at least "
391                        "one parameter whose type is not supported by the "
392                        "reflection\n", FE->getName().c_str());
393        return NULL;
394      }
395
396      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
397                  "Parameter packet must be a record");
398
399      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
400    }
401  }
402
403  if (FE->mIn) {
404    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
405    FE->mInType = RSExportType::Create(Context, T);
406    if (FE->mIsKernelStyle) {
407      slangAssert(FE->mInType);
408    }
409  }
410
411  if (FE->mIsKernelStyle && FE->mHasReturnType) {
412    const clang::Type *T = FE->mResultType.getTypePtr();
413    FE->mOutType = RSExportType::Create(Context, T);
414    slangAssert(FE->mOutType);
415  } else if (FE->mOut) {
416    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
417    FE->mOutType = RSExportType::Create(Context, T);
418  }
419
420  return FE;
421}
422
423RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
424  slangAssert(Context);
425  llvm::StringRef Name = "root";
426  RSExportForEach *FE = new RSExportForEach(Context, Name);
427  FE->mDummyRoot = true;
428  return FE;
429}
430
431bool RSExportForEach::isGraphicsRootRSFunc(int targetAPI,
432                                           const clang::FunctionDecl *FD) {
433  if (FD->hasAttr<clang::KernelAttr>()) {
434    return false;
435  }
436
437  if (!isRootRSFunc(FD)) {
438    return false;
439  }
440
441  if (FD->getNumParams() == 0) {
442    // Graphics root function
443    return true;
444  }
445
446  // Check for legacy graphics root function (with single parameter).
447  if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
448    const clang::QualType &IntType = FD->getASTContext().IntTy;
449    if (FD->getReturnType().getCanonicalType() == IntType) {
450      return true;
451    }
452  }
453
454  return false;
455}
456
457bool RSExportForEach::isRSForEachFunc(int targetAPI,
458    slang::RSContext* Context,
459    const clang::FunctionDecl *FD) {
460  slangAssert(Context && FD);
461  bool hasKernelAttr = FD->hasAttr<clang::KernelAttr>();
462
463  if (FD->getStorageClass() == clang::SC_Static) {
464    if (hasKernelAttr) {
465      Context->ReportError(FD->getLocation(),
466                           "Invalid use of attribute kernel with "
467                           "static function declaration: %0")
468          << FD->getName();
469    }
470    return false;
471  }
472
473  // Anything tagged as a kernel is definitely used with ForEach.
474  if (hasKernelAttr) {
475    return true;
476  }
477
478  if (isGraphicsRootRSFunc(targetAPI, FD)) {
479    return false;
480  }
481
482  // Check if first parameter is a pointer (which is required for ForEach).
483  unsigned int numParams = FD->getNumParams();
484
485  if (numParams > 0) {
486    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
487    clang::QualType QT = PVD->getType().getCanonicalType();
488
489    if (QT->isPointerType()) {
490      return true;
491    }
492
493    // Any non-graphics root() is automatically a ForEach candidate.
494    // At this point, however, we know that it is not going to be a valid
495    // compute root() function (due to not having a pointer parameter). We
496    // still want to return true here, so that we can issue appropriate
497    // diagnostics.
498    if (isRootRSFunc(FD)) {
499      return true;
500    }
501  }
502
503  return false;
504}
505
506bool
507RSExportForEach::validateSpecialFuncDecl(int targetAPI,
508                                         slang::RSContext *Context,
509                                         clang::FunctionDecl const *FD) {
510  slangAssert(Context && FD);
511  bool valid = true;
512  const clang::ASTContext &C = FD->getASTContext();
513  const clang::QualType &IntType = FD->getASTContext().IntTy;
514
515  if (isGraphicsRootRSFunc(targetAPI, FD)) {
516    if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
517      // Legacy graphics root function
518      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
519      clang::QualType QT = PVD->getType().getCanonicalType();
520      if (QT != IntType) {
521        Context->ReportError(PVD->getLocation(),
522                             "invalid parameter type for legacy "
523                             "graphics root() function: %0")
524            << PVD->getType();
525        valid = false;
526      }
527    }
528
529    // Graphics root function, so verify that it returns an int
530    if (FD->getReturnType().getCanonicalType() != IntType) {
531      Context->ReportError(FD->getLocation(),
532                           "root() is required to return "
533                           "an int for graphics usage");
534      valid = false;
535    }
536  } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
537    if (FD->getNumParams() != 0) {
538      Context->ReportError(FD->getLocation(),
539                           "%0(void) is required to have no "
540                           "parameters")
541          << FD->getName();
542      valid = false;
543    }
544
545    if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
546      Context->ReportError(FD->getLocation(),
547                           "%0(void) is required to have a void "
548                           "return type")
549          << FD->getName();
550      valid = false;
551    }
552  } else {
553    slangAssert(false && "must be called on root, init or .rs.dtor function!");
554  }
555
556  return valid;
557}
558
559}  // namespace slang
560