1/*
2 * Copyright 2011-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Decl.h"
23#include "clang/AST/TypeLoc.h"
24
25#include "llvm/DerivedTypes.h"
26#include "llvm/Target/TargetData.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31#include "slang_version.h"
32
33namespace slang {
34
35namespace {
36
37static void ReportNameError(clang::DiagnosticsEngine *DiagEngine,
38                            clang::ParmVarDecl const *PVD) {
39  slangAssert(DiagEngine && PVD);
40  const clang::SourceManager &SM = DiagEngine->getSourceManager();
41
42  DiagEngine->Report(
43    clang::FullSourceLoc(PVD->getLocation(), SM),
44    DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
45                                "Duplicate parameter entry "
46                                "(by position/name): '%0'"))
47    << PVD->getName();
48  return;
49}
50
51}  // namespace
52
53
54// This function takes care of additional validation and construction of
55// parameters related to forEach_* reflection.
56bool RSExportForEach::validateAndConstructParams(
57    RSContext *Context, const clang::FunctionDecl *FD) {
58  slangAssert(Context && FD);
59  bool valid = true;
60  clang::ASTContext &C = Context->getASTContext();
61  clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
62
63  numParams = FD->getNumParams();
64
65  if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
66    if (!isRootRSFunc(FD)) {
67      DiagEngine->Report(
68        clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
69        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
70                                    "Non-root compute kernel %0() is "
71                                    "not supported in SDK levels %1-%2"))
72        << FD->getName()
73        << SLANG_MINIMUM_TARGET_API
74        << (SLANG_JB_TARGET_API - 1);
75      return false;
76    }
77  }
78
79  mResultType = FD->getResultType().getCanonicalType();
80  // Compute kernel functions are required to return a void type or
81  // be marked explicitly as a kernel. In the case of
82  // "__attribute__((kernel))", we handle validation differently.
83  if (FD->hasAttr<clang::KernelAttr>()) {
84    return validateAndConstructKernelParams(Context, FD);
85  }
86
87  // If numParams is 0, we already marked this as a graphics root().
88  slangAssert(numParams > 0);
89
90  // Compute kernel functions of this type are required to return a void type.
91  if (mResultType != C.VoidTy) {
92    DiagEngine->Report(
93      clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
94      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
95                                  "Compute kernel %0() is required to return a "
96                                  "void type")) << FD->getName();
97    valid = false;
98  }
99
100  // Validate remaining parameter types
101  // TODO(all): Add support for LOD/face when we have them
102
103  size_t i = 0;
104  const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
105  clang::QualType QT = PVD->getType().getCanonicalType();
106
107  // Check for const T1 *in
108  if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
109    mIn = PVD;
110    i++;  // advance parameter pointer
111  }
112
113  // Check for T2 *out
114  if (i < numParams) {
115    PVD = FD->getParamDecl(i);
116    QT = PVD->getType().getCanonicalType();
117    if (QT->isPointerType() && !QT->getPointeeType().isConstQualified()) {
118      mOut = PVD;
119      i++;  // advance parameter pointer
120    }
121  }
122
123  if (!mIn && !mOut) {
124    DiagEngine->Report(
125      clang::FullSourceLoc(FD->getLocation(),
126                           DiagEngine->getSourceManager()),
127      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
128                                  "Compute kernel %0() must have at least one "
129                                  "parameter for in or out")) << FD->getName();
130    valid = false;
131  }
132
133  // Check for T3 *usrData
134  if (i < numParams) {
135    PVD = FD->getParamDecl(i);
136    QT = PVD->getType().getCanonicalType();
137    if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
138      mUsrData = PVD;
139      i++;  // advance parameter pointer
140    }
141  }
142
143  while (i < numParams) {
144    PVD = FD->getParamDecl(i);
145    QT = PVD->getType().getCanonicalType();
146
147    if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
148      DiagEngine->Report(
149        clang::FullSourceLoc(PVD->getLocation(),
150                             DiagEngine->getSourceManager()),
151        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
152                                    "Unexpected kernel %0() parameter '%1' "
153                                    "of type '%2'"))
154        << FD->getName() << PVD->getName() << PVD->getType().getAsString();
155      valid = false;
156    } else {
157      llvm::StringRef ParamName = PVD->getName();
158      if (ParamName.equals("x")) {
159        if (mX) {
160          ReportNameError(DiagEngine, PVD);
161          valid = false;
162        } else if (mY) {
163          // Can't go back to X after skipping Y
164          ReportNameError(DiagEngine, PVD);
165          valid = false;
166        } else {
167          mX = PVD;
168        }
169      } else if (ParamName.equals("y")) {
170        if (mY) {
171          ReportNameError(DiagEngine, PVD);
172          valid = false;
173        } else {
174          mY = PVD;
175        }
176      } else {
177        if (!mX && !mY) {
178          mX = PVD;
179        } else if (!mY) {
180          mY = PVD;
181        } else {
182          DiagEngine->Report(
183            clang::FullSourceLoc(PVD->getLocation(),
184                                 DiagEngine->getSourceManager()),
185            DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
186                                        "Unexpected kernel %0() parameter '%1' "
187                                        "of type '%2'"))
188            << FD->getName() << PVD->getName() << PVD->getType().getAsString();
189          valid = false;
190        }
191      }
192    }
193
194    i++;
195  }
196
197  mSignatureMetadata = 0;
198  if (valid) {
199    // Set up the bitwise metadata encoding for runtime argument passing.
200    mSignatureMetadata |= (mIn ?       0x01 : 0);
201    mSignatureMetadata |= (mOut ?      0x02 : 0);
202    mSignatureMetadata |= (mUsrData ?  0x04 : 0);
203    mSignatureMetadata |= (mX ?        0x08 : 0);
204    mSignatureMetadata |= (mY ?        0x10 : 0);
205  }
206
207  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
208    // APIs before ICS cannot skip between parameters. It is ok, however, for
209    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
210    if (mSignatureMetadata != 0x1f &&  // In, Out, UsrData, X, Y
211        mSignatureMetadata != 0x0f &&  // In, Out, UsrData, X
212        mSignatureMetadata != 0x07 &&  // In, Out, UsrData
213        mSignatureMetadata != 0x03 &&  // In, Out
214        mSignatureMetadata != 0x01) {  // In
215      DiagEngine->Report(
216        clang::FullSourceLoc(FD->getLocation(),
217                             DiagEngine->getSourceManager()),
218        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
219                                    "Compute kernel %0() targeting SDK levels "
220                                    "%1-%2 may not skip parameters"))
221        << FD->getName() << SLANG_MINIMUM_TARGET_API
222        << (SLANG_ICS_TARGET_API - 1);
223      valid = false;
224    }
225  }
226
227  return valid;
228}
229
230
231bool RSExportForEach::validateAndConstructKernelParams(RSContext *Context,
232    const clang::FunctionDecl *FD) {
233  slangAssert(Context && FD);
234  bool valid = true;
235  clang::ASTContext &C = Context->getASTContext();
236  clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
237
238  if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
239    DiagEngine->Report(
240      clang::FullSourceLoc(FD->getLocation(),
241                           DiagEngine->getSourceManager()),
242      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
243                                  "Compute kernel %0() targeting SDK levels "
244                                  "%1-%2 may not use pass-by-value with "
245                                  "__attribute__((kernel))"))
246      << FD->getName() << SLANG_MINIMUM_TARGET_API
247      << (SLANG_JB_MR1_TARGET_API - 1);
248    return false;
249  }
250
251  // Denote that we are indeed a pass-by-value kernel.
252  mKernel = true;
253
254  if (mResultType != C.VoidTy) {
255    mReturn = true;
256  }
257
258  if (mResultType->isPointerType()) {
259    DiagEngine->Report(
260      clang::FullSourceLoc(FD->getTypeSpecStartLoc(),
261                           DiagEngine->getSourceManager()),
262      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
263                                  "Compute kernel %0() cannot return a "
264                                  "pointer type: '%1'"))
265      << FD->getName() << mResultType.getAsString();
266    valid = false;
267  }
268
269  // Validate remaining parameter types
270  // TODO(all): Add support for LOD/face when we have them
271
272  size_t i = 0;
273  const clang::ParmVarDecl *PVD = NULL;
274  clang::QualType QT;
275
276  if (i < numParams) {
277    PVD = FD->getParamDecl(i);
278    QT = PVD->getType().getCanonicalType();
279
280    if (QT->isPointerType()) {
281      DiagEngine->Report(
282        clang::FullSourceLoc(PVD->getLocation(),
283                             DiagEngine->getSourceManager()),
284        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
285                                    "Compute kernel %0() cannot have "
286                                    "parameter '%1' of pointer type: '%2'"))
287        << FD->getName() << PVD->getName() << PVD->getType().getAsString();
288      valid = false;
289    } else if (QT.getUnqualifiedType() == C.UnsignedIntTy) {
290      // First parameter is either input or x, y (iff it is uint32_t).
291      llvm::StringRef ParamName = PVD->getName();
292      if (ParamName.equals("x")) {
293        mX = PVD;
294      } else if (ParamName.equals("y")) {
295        mY = PVD;
296      } else {
297        mIn = PVD;
298      }
299    } else {
300      mIn = PVD;
301    }
302
303    i++;  // advance parameter pointer
304  }
305
306  // Check that we have at least one allocation to use for dimensions.
307  if (valid && !mIn && !mReturn) {
308    DiagEngine->Report(
309      clang::FullSourceLoc(FD->getLocation(),
310                           DiagEngine->getSourceManager()),
311      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
312                                  "Compute kernel %0() must have at least one "
313                                  "input parameter or a non-void return "
314                                  "type")) << FD->getName();
315    valid = false;
316  }
317
318  // TODO: Abstract this block away, since it is duplicate code.
319  while (i < numParams) {
320    PVD = FD->getParamDecl(i);
321    QT = PVD->getType().getCanonicalType();
322
323    if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
324      DiagEngine->Report(
325        clang::FullSourceLoc(PVD->getLocation(),
326                             DiagEngine->getSourceManager()),
327        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
328                                    "Unexpected kernel %0() parameter '%1' "
329                                    "of type '%2'"))
330        << FD->getName() << PVD->getName() << PVD->getType().getAsString();
331      valid = false;
332    } else {
333      llvm::StringRef ParamName = PVD->getName();
334      if (ParamName.equals("x")) {
335        if (mX) {
336          ReportNameError(DiagEngine, PVD);
337          valid = false;
338        } else if (mY) {
339          // Can't go back to X after skipping Y
340          ReportNameError(DiagEngine, PVD);
341          valid = false;
342        } else {
343          mX = PVD;
344        }
345      } else if (ParamName.equals("y")) {
346        if (mY) {
347          ReportNameError(DiagEngine, PVD);
348          valid = false;
349        } else {
350          mY = PVD;
351        }
352      } else {
353        if (!mX && !mY) {
354          mX = PVD;
355        } else if (!mY) {
356          mY = PVD;
357        } else {
358          DiagEngine->Report(
359            clang::FullSourceLoc(PVD->getLocation(),
360                                 DiagEngine->getSourceManager()),
361            DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
362                                        "Unexpected kernel %0() parameter '%1' "
363                                        "of type '%2'"))
364            << FD->getName() << PVD->getName() << PVD->getType().getAsString();
365          valid = false;
366        }
367      }
368    }
369
370    i++;  // advance parameter pointer
371  }
372
373  mSignatureMetadata = 0;
374  if (valid) {
375    // Set up the bitwise metadata encoding for runtime argument passing.
376    mSignatureMetadata |= (mIn ?       0x01 : 0);
377    slangAssert(mOut == NULL);
378    mSignatureMetadata |= (mReturn ?   0x02 : 0);
379    slangAssert(mUsrData == NULL);
380    mSignatureMetadata |= (mUsrData ?  0x04 : 0);
381    mSignatureMetadata |= (mX ?        0x08 : 0);
382    mSignatureMetadata |= (mY ?        0x10 : 0);
383    mSignatureMetadata |= (mKernel ?   0x20 : 0);  // pass-by-value
384  }
385
386  return valid;
387}
388
389
390RSExportForEach *RSExportForEach::Create(RSContext *Context,
391                                         const clang::FunctionDecl *FD) {
392  slangAssert(Context && FD);
393  llvm::StringRef Name = FD->getName();
394  RSExportForEach *FE;
395
396  slangAssert(!Name.empty() && "Function must have a name");
397
398  FE = new RSExportForEach(Context, Name);
399
400  if (!FE->validateAndConstructParams(Context, FD)) {
401    return NULL;
402  }
403
404  clang::ASTContext &Ctx = Context->getASTContext();
405
406  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
407  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
408
409  // Extract the usrData parameter (if we have one)
410  if (FE->mUsrData) {
411    const clang::ParmVarDecl *PVD = FE->mUsrData;
412    clang::QualType QT = PVD->getType().getCanonicalType();
413    slangAssert(QT->isPointerType() &&
414                QT->getPointeeType().isConstQualified());
415
416    const clang::ASTContext &C = Context->getASTContext();
417    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
418        C.VoidTy) {
419      // In the case of using const void*, we can't reflect an appopriate
420      // Java type, so we fall back to just reflecting the ain/aout parameters
421      FE->mUsrData = NULL;
422    } else {
423      clang::RecordDecl *RD =
424          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
425                                    Ctx.getTranslationUnitDecl(),
426                                    clang::SourceLocation(),
427                                    clang::SourceLocation(),
428                                    &Ctx.Idents.get(Id));
429
430      clang::FieldDecl *FD =
431          clang::FieldDecl::Create(Ctx,
432                                   RD,
433                                   clang::SourceLocation(),
434                                   clang::SourceLocation(),
435                                   PVD->getIdentifier(),
436                                   QT->getPointeeType(),
437                                   NULL,
438                                   /* BitWidth = */ NULL,
439                                   /* Mutable = */ false,
440                                   /* HasInit = */ clang::ICIS_NoInit);
441      RD->addDecl(FD);
442      RD->completeDefinition();
443
444      // Create an export type iff we have a valid usrData type
445      clang::QualType T = Ctx.getTagDeclType(RD);
446      slangAssert(!T.isNull());
447
448      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
449
450      if (ET == NULL) {
451        fprintf(stderr, "Failed to export the function %s. There's at least "
452                        "one parameter whose type is not supported by the "
453                        "reflection\n", FE->getName().c_str());
454        return NULL;
455      }
456
457      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
458                  "Parameter packet must be a record");
459
460      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
461    }
462  }
463
464  if (FE->mIn) {
465    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
466    FE->mInType = RSExportType::Create(Context, T);
467    if (FE->mKernel) {
468      slangAssert(FE->mInType);
469    }
470  }
471
472  if (FE->mKernel && FE->mReturn) {
473    const clang::Type *T = FE->mResultType.getTypePtr();
474    FE->mOutType = RSExportType::Create(Context, T);
475    slangAssert(FE->mOutType);
476  } else if (FE->mOut) {
477    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
478    FE->mOutType = RSExportType::Create(Context, T);
479  }
480
481  return FE;
482}
483
484RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
485  slangAssert(Context);
486  llvm::StringRef Name = "root";
487  RSExportForEach *FE = new RSExportForEach(Context, Name);
488  FE->mDummyRoot = true;
489  return FE;
490}
491
492bool RSExportForEach::isGraphicsRootRSFunc(int targetAPI,
493                                           const clang::FunctionDecl *FD) {
494  if (FD->hasAttr<clang::KernelAttr>()) {
495    return false;
496  }
497
498  if (!isRootRSFunc(FD)) {
499    return false;
500  }
501
502  if (FD->getNumParams() == 0) {
503    // Graphics root function
504    return true;
505  }
506
507  // Check for legacy graphics root function (with single parameter).
508  if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
509    const clang::QualType &IntType = FD->getASTContext().IntTy;
510    if (FD->getResultType().getCanonicalType() == IntType) {
511      return true;
512    }
513  }
514
515  return false;
516}
517
518bool RSExportForEach::isRSForEachFunc(int targetAPI,
519    const clang::FunctionDecl *FD) {
520  // Anything tagged as a kernel is definitely used with ForEach.
521  if (FD->hasAttr<clang::KernelAttr>()) {
522    return true;
523  }
524
525  if (isGraphicsRootRSFunc(targetAPI, FD)) {
526    return false;
527  }
528
529  // Check if first parameter is a pointer (which is required for ForEach).
530  unsigned int numParams = FD->getNumParams();
531
532  if (numParams > 0) {
533    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
534    clang::QualType QT = PVD->getType().getCanonicalType();
535
536    if (QT->isPointerType()) {
537      return true;
538    }
539
540    // Any non-graphics root() is automatically a ForEach candidate.
541    // At this point, however, we know that it is not going to be a valid
542    // compute root() function (due to not having a pointer parameter). We
543    // still want to return true here, so that we can issue appropriate
544    // diagnostics.
545    if (isRootRSFunc(FD)) {
546      return true;
547    }
548  }
549
550  return false;
551}
552
553bool
554RSExportForEach::validateSpecialFuncDecl(int targetAPI,
555                                         clang::DiagnosticsEngine *DiagEngine,
556                                         clang::FunctionDecl const *FD) {
557  slangAssert(DiagEngine && FD);
558  bool valid = true;
559  const clang::ASTContext &C = FD->getASTContext();
560  const clang::QualType &IntType = FD->getASTContext().IntTy;
561
562  if (isGraphicsRootRSFunc(targetAPI, FD)) {
563    if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
564      // Legacy graphics root function
565      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
566      clang::QualType QT = PVD->getType().getCanonicalType();
567      if (QT != IntType) {
568        DiagEngine->Report(
569          clang::FullSourceLoc(PVD->getLocation(),
570                               DiagEngine->getSourceManager()),
571          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
572                                      "invalid parameter type for legacy "
573                                      "graphics root() function: %0"))
574          << PVD->getType();
575        valid = false;
576      }
577    }
578
579    // Graphics root function, so verify that it returns an int
580    if (FD->getResultType().getCanonicalType() != IntType) {
581      DiagEngine->Report(
582        clang::FullSourceLoc(FD->getLocation(),
583                             DiagEngine->getSourceManager()),
584        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
585                                    "root() is required to return "
586                                    "an int for graphics usage"));
587      valid = false;
588    }
589  } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
590    if (FD->getNumParams() != 0) {
591      DiagEngine->Report(
592          clang::FullSourceLoc(FD->getLocation(),
593                               DiagEngine->getSourceManager()),
594          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
595                                      "%0(void) is required to have no "
596                                      "parameters")) << FD->getName();
597      valid = false;
598    }
599
600    if (FD->getResultType().getCanonicalType() != C.VoidTy) {
601      DiagEngine->Report(
602          clang::FullSourceLoc(FD->getLocation(),
603                               DiagEngine->getSourceManager()),
604          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
605                                      "%0(void) is required to have a void "
606                                      "return type")) << FD->getName();
607      valid = false;
608    }
609  } else {
610    slangAssert(false && "must be called on root, init or .rs.dtor function!");
611  }
612
613  return valid;
614}
615
616}  // namespace slang
617