slang_rs_export_foreach.cpp revision 0f2a2397df53a1bb74609abe3c27719bc7e3c328
1/*
2 * Copyright 2011-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Attr.h"
23#include "clang/AST/Decl.h"
24#include "clang/AST/TypeLoc.h"
25
26#include "llvm/IR/DerivedTypes.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31#include "slang_version.h"
32
33namespace slang {
34
35namespace {
36
37static void ReportNameError(clang::DiagnosticsEngine *DiagEngine,
38                            clang::ParmVarDecl const *PVD) {
39  slangAssert(DiagEngine && PVD);
40  const clang::SourceManager &SM = DiagEngine->getSourceManager();
41
42  DiagEngine->Report(
43    clang::FullSourceLoc(PVD->getLocation(), SM),
44    DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
45                                "Duplicate parameter entry "
46                                "(by position/name): '%0'"))
47    << PVD->getName();
48  return;
49}
50
51}  // namespace
52
53
54// This function takes care of additional validation and construction of
55// parameters related to forEach_* reflection.
56bool RSExportForEach::validateAndConstructParams(
57    RSContext *Context, const clang::FunctionDecl *FD) {
58  slangAssert(Context && FD);
59  bool valid = true;
60  clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
61
62  numParams = FD->getNumParams();
63
64  if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
65    if (!isRootRSFunc(FD)) {
66      DiagEngine->Report(
67        clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
68        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
69                                    "Non-root compute kernel %0() is "
70                                    "not supported in SDK levels %1-%2"))
71        << FD->getName()
72        << SLANG_MINIMUM_TARGET_API
73        << (SLANG_JB_TARGET_API - 1);
74      return false;
75    }
76  }
77
78  mResultType = FD->getResultType().getCanonicalType();
79  // Compute kernel functions are defined differently when the
80  // "__attribute__((kernel))" is set.
81  if (FD->hasAttr<clang::KernelAttr>()) {
82    valid |= validateAndConstructKernelParams(Context, FD);
83  } else {
84    valid |= validateAndConstructOldStyleParams(Context, FD);
85  }
86  valid |= setSignatureMetadata(Context, FD);
87  return valid;
88}
89
90bool RSExportForEach::validateAndConstructOldStyleParams(RSContext *Context,
91    const clang::FunctionDecl *FD) {
92  slangAssert(Context && FD);
93  // If numParams is 0, we already marked this as a graphics root().
94  slangAssert(numParams > 0);
95
96  bool valid = true;
97  clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
98
99  // Compute kernel functions of this style are required to return a void type.
100  clang::ASTContext &C = Context->getASTContext();
101  if (mResultType != C.VoidTy) {
102    DiagEngine->Report(
103      clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
104      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
105                                  "Compute kernel %0() is required to return a "
106                                  "void type")) << FD->getName();
107    valid = false;
108  }
109
110  // Validate remaining parameter types
111  // TODO(all): Add support for LOD/face when we have them
112
113  size_t i = 0;
114  const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
115  clang::QualType QT = PVD->getType().getCanonicalType();
116
117  // Check for const T1 *in
118  if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
119    mIn = PVD;
120    i++;  // advance parameter pointer
121  }
122
123  // Check for T2 *out
124  if (i < numParams) {
125    PVD = FD->getParamDecl(i);
126    QT = PVD->getType().getCanonicalType();
127    if (QT->isPointerType() && !QT->getPointeeType().isConstQualified()) {
128      mOut = PVD;
129      i++;  // advance parameter pointer
130    }
131  }
132
133  if (!mIn && !mOut) {
134    DiagEngine->Report(
135      clang::FullSourceLoc(FD->getLocation(),
136                           DiagEngine->getSourceManager()),
137      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
138                                  "Compute kernel %0() must have at least one "
139                                  "parameter for in or out")) << FD->getName();
140    valid = false;
141  }
142
143  // Check for T3 *usrData
144  if (i < numParams) {
145    PVD = FD->getParamDecl(i);
146    QT = PVD->getType().getCanonicalType();
147    if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
148      mUsrData = PVD;
149      i++;  // advance parameter pointer
150    }
151  }
152
153  while (i < numParams) {
154    PVD = FD->getParamDecl(i);
155    QT = PVD->getType().getCanonicalType();
156
157    if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
158      DiagEngine->Report(
159        clang::FullSourceLoc(PVD->getLocation(),
160                             DiagEngine->getSourceManager()),
161        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
162                                    "Unexpected kernel %0() parameter '%1' "
163                                    "of type '%2'"))
164        << FD->getName() << PVD->getName() << PVD->getType().getAsString();
165      valid = false;
166    } else {
167      llvm::StringRef ParamName = PVD->getName();
168      if (ParamName.equals("x")) {
169        if (mX) {
170          ReportNameError(DiagEngine, PVD);
171          valid = false;
172        } else if (mY) {
173          // Can't go back to X after skipping Y
174          ReportNameError(DiagEngine, PVD);
175          valid = false;
176        } else {
177          mX = PVD;
178        }
179      } else if (ParamName.equals("y")) {
180        if (mY) {
181          ReportNameError(DiagEngine, PVD);
182          valid = false;
183        } else {
184          mY = PVD;
185        }
186      } else {
187        if (!mX && !mY) {
188          mX = PVD;
189        } else if (!mY) {
190          mY = PVD;
191        } else {
192          DiagEngine->Report(
193            clang::FullSourceLoc(PVD->getLocation(),
194                                 DiagEngine->getSourceManager()),
195            DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
196                                        "Unexpected kernel %0() parameter '%1' "
197                                        "of type '%2'"))
198            << FD->getName() << PVD->getName() << PVD->getType().getAsString();
199          valid = false;
200        }
201      }
202    }
203
204    i++;
205  }
206
207  return valid;
208}
209
210
211bool RSExportForEach::validateAndConstructKernelParams(RSContext *Context,
212    const clang::FunctionDecl *FD) {
213  slangAssert(Context && FD);
214  bool valid = true;
215  clang::ASTContext &C = Context->getASTContext();
216  clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
217
218  if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
219    DiagEngine->Report(
220      clang::FullSourceLoc(FD->getLocation(),
221                           DiagEngine->getSourceManager()),
222      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
223                                  "Compute kernel %0() targeting SDK levels "
224                                  "%1-%2 may not use pass-by-value with "
225                                  "__attribute__((kernel))"))
226      << FD->getName() << SLANG_MINIMUM_TARGET_API
227      << (SLANG_JB_MR1_TARGET_API - 1);
228    return false;
229  }
230
231  // Denote that we are indeed a pass-by-value kernel.
232  mIsKernelStyle = true;
233
234  if (mResultType != C.VoidTy) {
235    mHasReturnType = true;
236  }
237
238  if (mResultType->isPointerType()) {
239    DiagEngine->Report(
240      clang::FullSourceLoc(FD->getTypeSpecStartLoc(),
241                           DiagEngine->getSourceManager()),
242      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
243                                  "Compute kernel %0() cannot return a "
244                                  "pointer type: '%1'"))
245      << FD->getName() << mResultType.getAsString();
246    valid = false;
247  }
248
249  // Validate remaining parameter types
250  // TODO(all): Add support for LOD/face when we have them
251
252  size_t i = 0;
253  const clang::ParmVarDecl *PVD = NULL;
254  clang::QualType QT;
255
256  if (i < numParams) {
257    PVD = FD->getParamDecl(i);
258    QT = PVD->getType().getCanonicalType();
259
260    if (QT->isPointerType()) {
261      DiagEngine->Report(
262        clang::FullSourceLoc(PVD->getLocation(),
263                             DiagEngine->getSourceManager()),
264        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
265                                    "Compute kernel %0() cannot have "
266                                    "parameter '%1' of pointer type: '%2'"))
267        << FD->getName() << PVD->getName() << PVD->getType().getAsString();
268      valid = false;
269    } else if (QT.getUnqualifiedType() == C.UnsignedIntTy) {
270      // First parameter is either input or x, y (iff it is uint32_t).
271      llvm::StringRef ParamName = PVD->getName();
272      if (ParamName.equals("x")) {
273        mX = PVD;
274      } else if (ParamName.equals("y")) {
275        mY = PVD;
276      } else {
277        mIn = PVD;
278      }
279    } else {
280      mIn = PVD;
281    }
282
283    i++;  // advance parameter pointer
284  }
285
286  // Check that we have at least one allocation to use for dimensions.
287  if (valid && !mIn && !mHasReturnType) {
288    DiagEngine->Report(
289      clang::FullSourceLoc(FD->getLocation(),
290                           DiagEngine->getSourceManager()),
291      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
292                                  "Compute kernel %0() must have at least one "
293                                  "input parameter or a non-void return "
294                                  "type")) << FD->getName();
295    valid = false;
296  }
297
298  // TODO: Abstract this block away, since it is duplicate code.
299  while (i < numParams) {
300    PVD = FD->getParamDecl(i);
301    QT = PVD->getType().getCanonicalType();
302
303    if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
304      DiagEngine->Report(
305        clang::FullSourceLoc(PVD->getLocation(),
306                             DiagEngine->getSourceManager()),
307        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
308                                    "Unexpected kernel %0() parameter '%1' "
309                                    "of type '%2'"))
310        << FD->getName() << PVD->getName() << PVD->getType().getAsString();
311      valid = false;
312    } else {
313      llvm::StringRef ParamName = PVD->getName();
314      if (ParamName.equals("x")) {
315        if (mX) {
316          ReportNameError(DiagEngine, PVD);
317          valid = false;
318        } else if (mY) {
319          // Can't go back to X after skipping Y
320          ReportNameError(DiagEngine, PVD);
321          valid = false;
322        } else {
323          mX = PVD;
324        }
325      } else if (ParamName.equals("y")) {
326        if (mY) {
327          ReportNameError(DiagEngine, PVD);
328          valid = false;
329        } else {
330          mY = PVD;
331        }
332      } else {
333        if (!mX && !mY) {
334          mX = PVD;
335        } else if (!mY) {
336          mY = PVD;
337        } else {
338          DiagEngine->Report(
339            clang::FullSourceLoc(PVD->getLocation(),
340                                 DiagEngine->getSourceManager()),
341            DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
342                                        "Unexpected kernel %0() parameter '%1' "
343                                        "of type '%2'"))
344            << FD->getName() << PVD->getName() << PVD->getType().getAsString();
345          valid = false;
346        }
347      }
348    }
349
350    i++;  // advance parameter pointer
351  }
352  return valid;
353}
354
355bool RSExportForEach::setSignatureMetadata(RSContext *Context,
356                                           const clang::FunctionDecl *FD) {
357  mSignatureMetadata = 0;
358  bool valid = true;
359
360  if (mIsKernelStyle) {
361    slangAssert(mOut == NULL);
362    slangAssert(mUsrData == NULL);
363  } else {
364    slangAssert(!mHasReturnType);
365  }
366
367  // Set up the bitwise metadata encoding for runtime argument passing.
368  // TODO: If this bit field is re-used from C++ code, define the values in a header.
369  const bool HasOut = mOut || mHasReturnType;
370  mSignatureMetadata |= (mIn ?            0x01 : 0);
371  mSignatureMetadata |= (HasOut ?         0x02 : 0);
372  mSignatureMetadata |= (mUsrData ?       0x04 : 0);
373  mSignatureMetadata |= (mX ?             0x08 : 0);
374  mSignatureMetadata |= (mY ?             0x10 : 0);
375  mSignatureMetadata |= (mIsKernelStyle ? 0x20 : 0);  // pass-by-value
376
377  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
378    // APIs before ICS cannot skip between parameters. It is ok, however, for
379    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
380    if (mSignatureMetadata != 0x1f &&  // In, Out, UsrData, X, Y
381        mSignatureMetadata != 0x0f &&  // In, Out, UsrData, X
382        mSignatureMetadata != 0x07 &&  // In, Out, UsrData
383        mSignatureMetadata != 0x03 &&  // In, Out
384        mSignatureMetadata != 0x01) {  // In
385      clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
386      DiagEngine->Report(clang::FullSourceLoc(FD->getLocation(),
387                                              DiagEngine->getSourceManager()),
388                         DiagEngine->getCustomDiagID(
389                             clang::DiagnosticsEngine::Error,
390                             "Compute kernel %0() targeting SDK levels "
391                             "%1-%2 may not skip parameters"))
392          << FD->getName() << SLANG_MINIMUM_TARGET_API
393          << (SLANG_ICS_TARGET_API - 1);
394      valid = false;
395    }
396  }
397  return valid;
398}
399
400RSExportForEach *RSExportForEach::Create(RSContext *Context,
401                                         const clang::FunctionDecl *FD) {
402  slangAssert(Context && FD);
403  llvm::StringRef Name = FD->getName();
404  RSExportForEach *FE;
405
406  slangAssert(!Name.empty() && "Function must have a name");
407
408  FE = new RSExportForEach(Context, Name);
409
410  if (!FE->validateAndConstructParams(Context, FD)) {
411    return NULL;
412  }
413
414  clang::ASTContext &Ctx = Context->getASTContext();
415
416  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
417  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
418
419  // Extract the usrData parameter (if we have one)
420  if (FE->mUsrData) {
421    const clang::ParmVarDecl *PVD = FE->mUsrData;
422    clang::QualType QT = PVD->getType().getCanonicalType();
423    slangAssert(QT->isPointerType() &&
424                QT->getPointeeType().isConstQualified());
425
426    const clang::ASTContext &C = Context->getASTContext();
427    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
428        C.VoidTy) {
429      // In the case of using const void*, we can't reflect an appopriate
430      // Java type, so we fall back to just reflecting the ain/aout parameters
431      FE->mUsrData = NULL;
432    } else {
433      clang::RecordDecl *RD =
434          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
435                                    Ctx.getTranslationUnitDecl(),
436                                    clang::SourceLocation(),
437                                    clang::SourceLocation(),
438                                    &Ctx.Idents.get(Id));
439
440      clang::FieldDecl *FD =
441          clang::FieldDecl::Create(Ctx,
442                                   RD,
443                                   clang::SourceLocation(),
444                                   clang::SourceLocation(),
445                                   PVD->getIdentifier(),
446                                   QT->getPointeeType(),
447                                   NULL,
448                                   /* BitWidth = */ NULL,
449                                   /* Mutable = */ false,
450                                   /* HasInit = */ clang::ICIS_NoInit);
451      RD->addDecl(FD);
452      RD->completeDefinition();
453
454      // Create an export type iff we have a valid usrData type
455      clang::QualType T = Ctx.getTagDeclType(RD);
456      slangAssert(!T.isNull());
457
458      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
459
460      if (ET == NULL) {
461        fprintf(stderr, "Failed to export the function %s. There's at least "
462                        "one parameter whose type is not supported by the "
463                        "reflection\n", FE->getName().c_str());
464        return NULL;
465      }
466
467      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
468                  "Parameter packet must be a record");
469
470      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
471    }
472  }
473
474  if (FE->mIn) {
475    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
476    FE->mInType = RSExportType::Create(Context, T);
477    if (FE->mIsKernelStyle) {
478      slangAssert(FE->mInType);
479    }
480  }
481
482  if (FE->mIsKernelStyle && FE->mHasReturnType) {
483    const clang::Type *T = FE->mResultType.getTypePtr();
484    FE->mOutType = RSExportType::Create(Context, T);
485    slangAssert(FE->mOutType);
486  } else if (FE->mOut) {
487    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
488    FE->mOutType = RSExportType::Create(Context, T);
489  }
490
491  return FE;
492}
493
494RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
495  slangAssert(Context);
496  llvm::StringRef Name = "root";
497  RSExportForEach *FE = new RSExportForEach(Context, Name);
498  FE->mDummyRoot = true;
499  return FE;
500}
501
502bool RSExportForEach::isGraphicsRootRSFunc(int targetAPI,
503                                           const clang::FunctionDecl *FD) {
504  if (FD->hasAttr<clang::KernelAttr>()) {
505    return false;
506  }
507
508  if (!isRootRSFunc(FD)) {
509    return false;
510  }
511
512  if (FD->getNumParams() == 0) {
513    // Graphics root function
514    return true;
515  }
516
517  // Check for legacy graphics root function (with single parameter).
518  if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
519    const clang::QualType &IntType = FD->getASTContext().IntTy;
520    if (FD->getResultType().getCanonicalType() == IntType) {
521      return true;
522    }
523  }
524
525  return false;
526}
527
528bool RSExportForEach::isRSForEachFunc(int targetAPI,
529    clang::DiagnosticsEngine *DiagEngine,
530    const clang::FunctionDecl *FD) {
531  slangAssert(DiagEngine && FD);
532  bool hasKernelAttr = FD->hasAttr<clang::KernelAttr>();
533
534  if (FD->getStorageClass() == clang::SC_Static) {
535    if (hasKernelAttr) {
536      DiagEngine->Report(
537        clang::FullSourceLoc(FD->getLocation(),
538                             DiagEngine->getSourceManager()),
539        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
540                                    "Invalid use of attribute kernel with "
541                                    "static function declaration: %0"))
542        << FD->getName();
543    }
544    return false;
545  }
546
547  // Anything tagged as a kernel is definitely used with ForEach.
548  if (hasKernelAttr) {
549    return true;
550  }
551
552  if (isGraphicsRootRSFunc(targetAPI, FD)) {
553    return false;
554  }
555
556  // Check if first parameter is a pointer (which is required for ForEach).
557  unsigned int numParams = FD->getNumParams();
558
559  if (numParams > 0) {
560    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
561    clang::QualType QT = PVD->getType().getCanonicalType();
562
563    if (QT->isPointerType()) {
564      return true;
565    }
566
567    // Any non-graphics root() is automatically a ForEach candidate.
568    // At this point, however, we know that it is not going to be a valid
569    // compute root() function (due to not having a pointer parameter). We
570    // still want to return true here, so that we can issue appropriate
571    // diagnostics.
572    if (isRootRSFunc(FD)) {
573      return true;
574    }
575  }
576
577  return false;
578}
579
580bool
581RSExportForEach::validateSpecialFuncDecl(int targetAPI,
582                                         clang::DiagnosticsEngine *DiagEngine,
583                                         clang::FunctionDecl const *FD) {
584  slangAssert(DiagEngine && FD);
585  bool valid = true;
586  const clang::ASTContext &C = FD->getASTContext();
587  const clang::QualType &IntType = FD->getASTContext().IntTy;
588
589  if (isGraphicsRootRSFunc(targetAPI, FD)) {
590    if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
591      // Legacy graphics root function
592      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
593      clang::QualType QT = PVD->getType().getCanonicalType();
594      if (QT != IntType) {
595        DiagEngine->Report(
596          clang::FullSourceLoc(PVD->getLocation(),
597                               DiagEngine->getSourceManager()),
598          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
599                                      "invalid parameter type for legacy "
600                                      "graphics root() function: %0"))
601          << PVD->getType();
602        valid = false;
603      }
604    }
605
606    // Graphics root function, so verify that it returns an int
607    if (FD->getResultType().getCanonicalType() != IntType) {
608      DiagEngine->Report(
609        clang::FullSourceLoc(FD->getLocation(),
610                             DiagEngine->getSourceManager()),
611        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
612                                    "root() is required to return "
613                                    "an int for graphics usage"));
614      valid = false;
615    }
616  } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
617    if (FD->getNumParams() != 0) {
618      DiagEngine->Report(
619          clang::FullSourceLoc(FD->getLocation(),
620                               DiagEngine->getSourceManager()),
621          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
622                                      "%0(void) is required to have no "
623                                      "parameters")) << FD->getName();
624      valid = false;
625    }
626
627    if (FD->getResultType().getCanonicalType() != C.VoidTy) {
628      DiagEngine->Report(
629          clang::FullSourceLoc(FD->getLocation(),
630                               DiagEngine->getSourceManager()),
631          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
632                                      "%0(void) is required to have a void "
633                                      "return type")) << FD->getName();
634      valid = false;
635    }
636  } else {
637    slangAssert(false && "must be called on root, init or .rs.dtor function!");
638  }
639
640  return valid;
641}
642
643}  // namespace slang
644