slang_rs_export_foreach.cpp revision 42f81b2b44205f421c6bd4727ce8c25b0effcb55
1/*
2 * Copyright 2011-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Attr.h"
23#include "clang/AST/Decl.h"
24#include "clang/AST/TypeLoc.h"
25
26#include "llvm/IR/DerivedTypes.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31#include "slang_version.h"
32
33namespace slang {
34
35// This function takes care of additional validation and construction of
36// parameters related to forEach_* reflection.
37bool RSExportForEach::validateAndConstructParams(
38    RSContext *Context, const clang::FunctionDecl *FD) {
39  slangAssert(Context && FD);
40  bool valid = true;
41
42  numParams = FD->getNumParams();
43
44  if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
45    if (!isRootRSFunc(FD)) {
46      Context->ReportError(FD->getLocation(),
47                           "Non-root compute kernel %0() is "
48                           "not supported in SDK levels %1-%2")
49          << FD->getName() << SLANG_MINIMUM_TARGET_API
50          << (SLANG_JB_TARGET_API - 1);
51      return false;
52    }
53  }
54
55  mResultType = FD->getResultType().getCanonicalType();
56  // Compute kernel functions are defined differently when the
57  // "__attribute__((kernel))" is set.
58  if (FD->hasAttr<clang::KernelAttr>()) {
59    valid |= validateAndConstructKernelParams(Context, FD);
60  } else {
61    valid |= validateAndConstructOldStyleParams(Context, FD);
62  }
63  valid |= setSignatureMetadata(Context, FD);
64  return valid;
65}
66
67bool RSExportForEach::validateAndConstructOldStyleParams(
68    RSContext *Context, const clang::FunctionDecl *FD) {
69  slangAssert(Context && FD);
70  // If numParams is 0, we already marked this as a graphics root().
71  slangAssert(numParams > 0);
72
73  bool valid = true;
74
75  // Compute kernel functions of this style are required to return a void type.
76  clang::ASTContext &C = Context->getASTContext();
77  if (mResultType != C.VoidTy) {
78    Context->ReportError(FD->getLocation(),
79                         "Compute kernel %0() is required to return a "
80                         "void type")
81        << FD->getName();
82    valid = false;
83  }
84
85  // Validate remaining parameter types
86  // TODO(all): Add support for LOD/face when we have them
87
88  size_t IndexOfFirstIterator = numParams;
89  valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator);
90
91  // Validate the non-iterator parameters, which should all be found before the
92  // first iterator.
93  for (size_t i = 0; i < IndexOfFirstIterator; i++) {
94    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
95    clang::QualType QT = PVD->getType().getCanonicalType();
96
97    if (!QT->isPointerType()) {
98      Context->ReportError(PVD->getLocation(),
99                           "Compute kernel %0() cannot have non-pointer "
100                           "parameters besides 'x' and 'y'. Parameter '%1' is "
101                           "of type: '%2'")
102          << FD->getName() << PVD->getName() << PVD->getType().getAsString();
103      valid = false;
104      continue;
105    }
106
107    // The only non-const pointer should be out.
108    if (!QT->getPointeeType().isConstQualified()) {
109      if (mOut == NULL) {
110        mOut = PVD;
111      } else {
112        Context->ReportError(PVD->getLocation(),
113                             "Compute kernel %0() can only have one non-const "
114                             "pointer parameter. Parameters '%1' and '%2' are "
115                             "both non-const.")
116            << FD->getName() << mOut->getName() << PVD->getName();
117        valid = false;
118      }
119    } else {
120      if (mIn == NULL && mOut == NULL) {
121        mIn = PVD;
122      } else if (mUsrData == NULL) {
123        mUsrData = PVD;
124      } else {
125        Context->ReportError(
126            PVD->getLocation(),
127            "Unexpected parameter '%0' for compute kernel %1()")
128            << PVD->getName() << FD->getName();
129        valid = false;
130      }
131    }
132  }
133
134  if (!mIn && !mOut) {
135    Context->ReportError(FD->getLocation(),
136                         "Compute kernel %0() must have at least one "
137                         "parameter for in or out")
138        << FD->getName();
139    valid = false;
140  }
141
142  return valid;
143}
144
145bool RSExportForEach::validateAndConstructKernelParams(
146    RSContext *Context, const clang::FunctionDecl *FD) {
147  slangAssert(Context && FD);
148  bool valid = true;
149  clang::ASTContext &C = Context->getASTContext();
150
151  if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
152    Context->ReportError(FD->getLocation(),
153                         "Compute kernel %0() targeting SDK levels "
154                         "%1-%2 may not use pass-by-value with "
155                         "__attribute__((kernel))")
156        << FD->getName() << SLANG_MINIMUM_TARGET_API
157        << (SLANG_JB_MR1_TARGET_API - 1);
158    return false;
159  }
160
161  // Denote that we are indeed a pass-by-value kernel.
162  mIsKernelStyle = true;
163  mHasReturnType = (mResultType != C.VoidTy);
164
165  if (mResultType->isPointerType()) {
166    Context->ReportError(
167        FD->getTypeSpecStartLoc(),
168        "Compute kernel %0() cannot return a pointer type: '%1'")
169        << FD->getName() << mResultType.getAsString();
170    valid = false;
171  }
172
173  // Validate remaining parameter types
174  // TODO(all): Add support for LOD/face when we have them
175
176  size_t IndexOfFirstIterator = numParams;
177  valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator);
178
179  // Validate the non-iterator parameters, which should all be found before the
180  // first iterator.
181  for (size_t i = 0; i < IndexOfFirstIterator; i++) {
182    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
183    if (i == 0) {
184      mIn = PVD;
185    } else {
186      Context->ReportError(PVD->getLocation(),
187                           "Unrecognized parameter '%0'. Compute kernel %1() "
188                           "can only have one input parameter, 'x', and 'y'")
189          << PVD->getName() << FD->getName();
190      valid = false;
191    }
192    clang::QualType QT = PVD->getType().getCanonicalType();
193    if (QT->isPointerType()) {
194      Context->ReportError(PVD->getLocation(),
195                           "Compute kernel %0() cannot have "
196                           "parameter '%1' of pointer type: '%2'")
197          << FD->getName() << PVD->getName() << PVD->getType().getAsString();
198      valid = false;
199    }
200  }
201
202  // Check that we have at least one allocation to use for dimensions.
203  if (valid && !mIn && !mHasReturnType) {
204    Context->ReportError(FD->getLocation(),
205                         "Compute kernel %0() must have at least one "
206                         "input parameter or a non-void return "
207                         "type")
208        << FD->getName();
209    valid = false;
210  }
211
212  return valid;
213}
214
215// Search for the optional x and y parameters.  Returns true if valid.   Also
216// sets *IndexOfFirstIterator to the index of the first iterator parameter, or
217// FD->getNumParams() if none are found.
218bool RSExportForEach::validateIterationParameters(
219    RSContext *Context, const clang::FunctionDecl *FD,
220    size_t *IndexOfFirstIterator) {
221  slangAssert(IndexOfFirstIterator != NULL);
222  slangAssert(mX == NULL && mY == NULL);
223  clang::ASTContext &C = Context->getASTContext();
224
225  // Find the x and y parameters if present.
226  size_t NumParams = FD->getNumParams();
227  *IndexOfFirstIterator = NumParams;
228  bool valid = true;
229  for (size_t i = 0; i < NumParams; i++) {
230    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
231    llvm::StringRef ParamName = PVD->getName();
232    if (ParamName.equals("x")) {
233      slangAssert(mX == NULL);  // We won't be invoked if two 'x' are present.
234      mX = PVD;
235      if (mY != NULL) {
236        Context->ReportError(PVD->getLocation(),
237                             "In compute kernel %0(), parameter 'x' should "
238                             "be defined before parameter 'y'")
239            << FD->getName();
240        valid = false;
241      }
242    } else if (ParamName.equals("y")) {
243      slangAssert(mY == NULL);  // We won't be invoked if two 'y' are present.
244      mY = PVD;
245    } else {
246      // It's neither x nor y.
247      if (*IndexOfFirstIterator < NumParams) {
248        Context->ReportError(PVD->getLocation(),
249                             "In compute kernel %0(), parameter '%1' cannot "
250                             "appear after the 'x' and 'y' parameters")
251            << FD->getName() << ParamName;
252        valid = false;
253      }
254      continue;
255    }
256    // Validate the data type of x and y.
257    clang::QualType QT = PVD->getType().getCanonicalType();
258    if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
259      Context->ReportError(
260          PVD->getLocation(),
261          "Parameter '%0' must be of type 'unsigned int'. It is of type '%1'")
262          << ParamName << PVD->getType().getAsString();
263      valid = false;
264    }
265    // If this is the first time we find an iterator, save it.
266    if (*IndexOfFirstIterator >= NumParams) {
267      *IndexOfFirstIterator = i;
268    }
269  }
270  return valid;
271}
272
273bool RSExportForEach::setSignatureMetadata(RSContext *Context,
274                                           const clang::FunctionDecl *FD) {
275  mSignatureMetadata = 0;
276  bool valid = true;
277
278  if (mIsKernelStyle) {
279    slangAssert(mOut == NULL);
280    slangAssert(mUsrData == NULL);
281  } else {
282    slangAssert(!mHasReturnType);
283  }
284
285  // Set up the bitwise metadata encoding for runtime argument passing.
286  // TODO: If this bit field is re-used from C++ code, define the values in a header.
287  const bool HasOut = mOut || mHasReturnType;
288  mSignatureMetadata |= (mIn ?            0x01 : 0);
289  mSignatureMetadata |= (HasOut ?         0x02 : 0);
290  mSignatureMetadata |= (mUsrData ?       0x04 : 0);
291  mSignatureMetadata |= (mX ?             0x08 : 0);
292  mSignatureMetadata |= (mY ?             0x10 : 0);
293  mSignatureMetadata |= (mIsKernelStyle ? 0x20 : 0);  // pass-by-value
294
295  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
296    // APIs before ICS cannot skip between parameters. It is ok, however, for
297    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
298    if (mSignatureMetadata != 0x1f &&  // In, Out, UsrData, X, Y
299        mSignatureMetadata != 0x0f &&  // In, Out, UsrData, X
300        mSignatureMetadata != 0x07 &&  // In, Out, UsrData
301        mSignatureMetadata != 0x03 &&  // In, Out
302        mSignatureMetadata != 0x01) {  // In
303      Context->ReportError(FD->getLocation(),
304                           "Compute kernel %0() targeting SDK levels "
305                           "%1-%2 may not skip parameters")
306          << FD->getName() << SLANG_MINIMUM_TARGET_API
307          << (SLANG_ICS_TARGET_API - 1);
308      valid = false;
309    }
310  }
311  return valid;
312}
313
314RSExportForEach *RSExportForEach::Create(RSContext *Context,
315                                         const clang::FunctionDecl *FD) {
316  slangAssert(Context && FD);
317  llvm::StringRef Name = FD->getName();
318  RSExportForEach *FE;
319
320  slangAssert(!Name.empty() && "Function must have a name");
321
322  FE = new RSExportForEach(Context, Name);
323
324  if (!FE->validateAndConstructParams(Context, FD)) {
325    return NULL;
326  }
327
328  clang::ASTContext &Ctx = Context->getASTContext();
329
330  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
331  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
332
333  // Extract the usrData parameter (if we have one)
334  if (FE->mUsrData) {
335    const clang::ParmVarDecl *PVD = FE->mUsrData;
336    clang::QualType QT = PVD->getType().getCanonicalType();
337    slangAssert(QT->isPointerType() &&
338                QT->getPointeeType().isConstQualified());
339
340    const clang::ASTContext &C = Context->getASTContext();
341    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
342        C.VoidTy) {
343      // In the case of using const void*, we can't reflect an appopriate
344      // Java type, so we fall back to just reflecting the ain/aout parameters
345      FE->mUsrData = NULL;
346    } else {
347      clang::RecordDecl *RD =
348          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
349                                    Ctx.getTranslationUnitDecl(),
350                                    clang::SourceLocation(),
351                                    clang::SourceLocation(),
352                                    &Ctx.Idents.get(Id));
353
354      clang::FieldDecl *FD =
355          clang::FieldDecl::Create(Ctx,
356                                   RD,
357                                   clang::SourceLocation(),
358                                   clang::SourceLocation(),
359                                   PVD->getIdentifier(),
360                                   QT->getPointeeType(),
361                                   NULL,
362                                   /* BitWidth = */ NULL,
363                                   /* Mutable = */ false,
364                                   /* HasInit = */ clang::ICIS_NoInit);
365      RD->addDecl(FD);
366      RD->completeDefinition();
367
368      // Create an export type iff we have a valid usrData type
369      clang::QualType T = Ctx.getTagDeclType(RD);
370      slangAssert(!T.isNull());
371
372      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
373
374      if (ET == NULL) {
375        fprintf(stderr, "Failed to export the function %s. There's at least "
376                        "one parameter whose type is not supported by the "
377                        "reflection\n", FE->getName().c_str());
378        return NULL;
379      }
380
381      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
382                  "Parameter packet must be a record");
383
384      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
385    }
386  }
387
388  if (FE->mIn) {
389    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
390    FE->mInType = RSExportType::Create(Context, T);
391    if (FE->mIsKernelStyle) {
392      slangAssert(FE->mInType);
393    }
394  }
395
396  if (FE->mIsKernelStyle && FE->mHasReturnType) {
397    const clang::Type *T = FE->mResultType.getTypePtr();
398    FE->mOutType = RSExportType::Create(Context, T);
399    slangAssert(FE->mOutType);
400  } else if (FE->mOut) {
401    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
402    FE->mOutType = RSExportType::Create(Context, T);
403  }
404
405  return FE;
406}
407
408RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
409  slangAssert(Context);
410  llvm::StringRef Name = "root";
411  RSExportForEach *FE = new RSExportForEach(Context, Name);
412  FE->mDummyRoot = true;
413  return FE;
414}
415
416bool RSExportForEach::isGraphicsRootRSFunc(int targetAPI,
417                                           const clang::FunctionDecl *FD) {
418  if (FD->hasAttr<clang::KernelAttr>()) {
419    return false;
420  }
421
422  if (!isRootRSFunc(FD)) {
423    return false;
424  }
425
426  if (FD->getNumParams() == 0) {
427    // Graphics root function
428    return true;
429  }
430
431  // Check for legacy graphics root function (with single parameter).
432  if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
433    const clang::QualType &IntType = FD->getASTContext().IntTy;
434    if (FD->getResultType().getCanonicalType() == IntType) {
435      return true;
436    }
437  }
438
439  return false;
440}
441
442bool RSExportForEach::isRSForEachFunc(int targetAPI,
443    slang::RSContext* Context,
444    const clang::FunctionDecl *FD) {
445  slangAssert(Context && FD);
446  bool hasKernelAttr = FD->hasAttr<clang::KernelAttr>();
447
448  if (FD->getStorageClass() == clang::SC_Static) {
449    if (hasKernelAttr) {
450      Context->ReportError(FD->getLocation(),
451                           "Invalid use of attribute kernel with "
452                           "static function declaration: %0")
453          << FD->getName();
454    }
455    return false;
456  }
457
458  // Anything tagged as a kernel is definitely used with ForEach.
459  if (hasKernelAttr) {
460    return true;
461  }
462
463  if (isGraphicsRootRSFunc(targetAPI, FD)) {
464    return false;
465  }
466
467  // Check if first parameter is a pointer (which is required for ForEach).
468  unsigned int numParams = FD->getNumParams();
469
470  if (numParams > 0) {
471    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
472    clang::QualType QT = PVD->getType().getCanonicalType();
473
474    if (QT->isPointerType()) {
475      return true;
476    }
477
478    // Any non-graphics root() is automatically a ForEach candidate.
479    // At this point, however, we know that it is not going to be a valid
480    // compute root() function (due to not having a pointer parameter). We
481    // still want to return true here, so that we can issue appropriate
482    // diagnostics.
483    if (isRootRSFunc(FD)) {
484      return true;
485    }
486  }
487
488  return false;
489}
490
491bool
492RSExportForEach::validateSpecialFuncDecl(int targetAPI,
493                                         slang::RSContext *Context,
494                                         clang::FunctionDecl const *FD) {
495  slangAssert(Context && FD);
496  bool valid = true;
497  const clang::ASTContext &C = FD->getASTContext();
498  const clang::QualType &IntType = FD->getASTContext().IntTy;
499
500  if (isGraphicsRootRSFunc(targetAPI, FD)) {
501    if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
502      // Legacy graphics root function
503      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
504      clang::QualType QT = PVD->getType().getCanonicalType();
505      if (QT != IntType) {
506        Context->ReportError(PVD->getLocation(),
507                             "invalid parameter type for legacy "
508                             "graphics root() function: %0")
509            << PVD->getType();
510        valid = false;
511      }
512    }
513
514    // Graphics root function, so verify that it returns an int
515    if (FD->getResultType().getCanonicalType() != IntType) {
516      Context->ReportError(FD->getLocation(),
517                           "root() is required to return "
518                           "an int for graphics usage");
519      valid = false;
520    }
521  } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
522    if (FD->getNumParams() != 0) {
523      Context->ReportError(FD->getLocation(),
524                           "%0(void) is required to have no "
525                           "parameters")
526          << FD->getName();
527      valid = false;
528    }
529
530    if (FD->getResultType().getCanonicalType() != C.VoidTy) {
531      Context->ReportError(FD->getLocation(),
532                           "%0(void) is required to have a void "
533                           "return type")
534          << FD->getName();
535      valid = false;
536    }
537  } else {
538    slangAssert(false && "must be called on root, init or .rs.dtor function!");
539  }
540
541  return valid;
542}
543
544}  // namespace slang
545