slang_rs_export_foreach.cpp revision 7b51b55e4467605a599e868a0dde7cb95c5ab76e
1/*
2 * Copyright 2011-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Decl.h"
23#include "clang/AST/TypeLoc.h"
24
25#include "llvm/DerivedTypes.h"
26#include "llvm/Target/TargetData.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31#include "slang_version.h"
32
33namespace slang {
34
35namespace {
36
37static void ReportNameError(clang::DiagnosticsEngine *DiagEngine,
38                            clang::ParmVarDecl const *PVD) {
39  slangAssert(DiagEngine && PVD);
40  const clang::SourceManager &SM = DiagEngine->getSourceManager();
41
42  DiagEngine->Report(
43    clang::FullSourceLoc(PVD->getLocation(), SM),
44    DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
45                                "Duplicate parameter entry "
46                                "(by position/name): '%0'"))
47    << PVD->getName();
48  return;
49}
50
51}  // namespace
52
53// This function takes care of additional validation and construction of
54// parameters related to forEach_* reflection.
55bool RSExportForEach::validateAndConstructParams(
56    RSContext *Context, const clang::FunctionDecl *FD) {
57  slangAssert(Context && FD);
58  bool valid = true;
59  clang::ASTContext &C = Context->getASTContext();
60  clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
61
62  numParams = FD->getNumParams();
63  slangAssert(numParams > 0);
64
65  if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
66    if (!isRootRSFunc(FD)) {
67      DiagEngine->Report(
68        clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
69        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
70                                    "Non-root compute kernel %0() is "
71                                    "not supported in SDK levels %1-%2"))
72        << FD->getName()
73        << SLANG_MINIMUM_TARGET_API
74        << (SLANG_JB_TARGET_API - 1);
75      return false;
76    }
77  }
78
79  // Compute kernel functions are required to return a void type for now
80  if (FD->getResultType().getCanonicalType() != C.VoidTy) {
81    DiagEngine->Report(
82      clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
83      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
84                                  "Compute kernel %0() is required to return a "
85                                  "void type")) << FD->getName();
86    valid = false;
87  }
88
89  // Validate remaining parameter types
90  // TODO(all): Add support for LOD/face when we have them
91
92  size_t i = 0;
93  const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
94  clang::QualType QT = PVD->getType().getCanonicalType();
95
96  // Check for const T1 *in
97  if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
98    mIn = PVD;
99    i++;  // advance parameter pointer
100  }
101
102  // Check for T2 *out
103  if (i < numParams) {
104    PVD = FD->getParamDecl(i);
105    QT = PVD->getType().getCanonicalType();
106    if (QT->isPointerType() && !QT->getPointeeType().isConstQualified()) {
107      mOut = PVD;
108      i++;  // advance parameter pointer
109    }
110  }
111
112  if (!mIn && !mOut) {
113    DiagEngine->Report(
114      clang::FullSourceLoc(FD->getLocation(),
115                           DiagEngine->getSourceManager()),
116      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
117                                  "Compute kernel %0() must have at least one "
118                                  "parameter for in or out")) << FD->getName();
119    valid = false;
120  }
121
122  // Check for T3 *usrData
123  if (i < numParams) {
124    PVD = FD->getParamDecl(i);
125    QT = PVD->getType().getCanonicalType();
126    if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
127      mUsrData = PVD;
128      i++;  // advance parameter pointer
129    }
130  }
131
132  while (i < numParams) {
133    PVD = FD->getParamDecl(i);
134    QT = PVD->getType().getCanonicalType();
135
136    if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
137      DiagEngine->Report(
138        clang::FullSourceLoc(PVD->getLocation(),
139                             DiagEngine->getSourceManager()),
140        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
141                                    "Unexpected kernel %0() parameter '%1' "
142                                    "of type '%2'"))
143        << FD->getName() << PVD->getName() << PVD->getType().getAsString();
144      valid = false;
145    } else {
146      llvm::StringRef ParamName = PVD->getName();
147      if (ParamName.equals("x")) {
148        if (mX) {
149          ReportNameError(DiagEngine, PVD);
150          valid = false;
151        } else if (mY) {
152          // Can't go back to X after skipping Y
153          ReportNameError(DiagEngine, PVD);
154          valid = false;
155        } else {
156          mX = PVD;
157        }
158      } else if (ParamName.equals("y")) {
159        if (mY) {
160          ReportNameError(DiagEngine, PVD);
161          valid = false;
162        } else {
163          mY = PVD;
164        }
165      } else {
166        if (!mX && !mY) {
167          mX = PVD;
168        } else if (!mY) {
169          mY = PVD;
170        } else {
171          DiagEngine->Report(
172            clang::FullSourceLoc(PVD->getLocation(),
173                                 DiagEngine->getSourceManager()),
174            DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
175                                        "Unexpected kernel %0() parameter '%1' "
176                                        "of type '%2'"))
177            << FD->getName() << PVD->getName() << PVD->getType().getAsString();
178          valid = false;
179        }
180      }
181    }
182
183    i++;
184  }
185
186  mSignatureMetadata = 0;
187  if (valid) {
188    // Set up the bitwise metadata encoding for runtime argument passing.
189    mSignatureMetadata |= (mIn ?       0x01 : 0);
190    mSignatureMetadata |= (mOut ?      0x02 : 0);
191    mSignatureMetadata |= (mUsrData ?  0x04 : 0);
192    mSignatureMetadata |= (mX ?        0x08 : 0);
193    mSignatureMetadata |= (mY ?        0x10 : 0);
194  }
195
196  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
197    // APIs before ICS cannot skip between parameters. It is ok, however, for
198    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
199    if (mSignatureMetadata != 0x1f &&  // In, Out, UsrData, X, Y
200        mSignatureMetadata != 0x0f &&  // In, Out, UsrData, X
201        mSignatureMetadata != 0x07 &&  // In, Out, UsrData
202        mSignatureMetadata != 0x03 &&  // In, Out
203        mSignatureMetadata != 0x01) {  // In
204      DiagEngine->Report(
205        clang::FullSourceLoc(FD->getLocation(),
206                             DiagEngine->getSourceManager()),
207        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
208                                    "Compute kernel %0() targeting SDK levels "
209                                    "%1-%2 may not skip parameters"))
210        << FD->getName() << SLANG_MINIMUM_TARGET_API
211        << (SLANG_ICS_TARGET_API - 1);
212      valid = false;
213    }
214  }
215
216  return valid;
217}
218
219RSExportForEach *RSExportForEach::Create(RSContext *Context,
220                                         const clang::FunctionDecl *FD) {
221  slangAssert(Context && FD);
222  llvm::StringRef Name = FD->getName();
223  RSExportForEach *FE;
224
225  slangAssert(!Name.empty() && "Function must have a name");
226
227  FE = new RSExportForEach(Context, Name, FD);
228
229  if (!FE->validateAndConstructParams(Context, FD)) {
230    return NULL;
231  }
232
233  clang::ASTContext &Ctx = Context->getASTContext();
234
235  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
236  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
237
238  // Extract the usrData parameter (if we have one)
239  if (FE->mUsrData) {
240    const clang::ParmVarDecl *PVD = FE->mUsrData;
241    clang::QualType QT = PVD->getType().getCanonicalType();
242    slangAssert(QT->isPointerType() &&
243                QT->getPointeeType().isConstQualified());
244
245    const clang::ASTContext &C = Context->getASTContext();
246    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
247        C.VoidTy) {
248      // In the case of using const void*, we can't reflect an appopriate
249      // Java type, so we fall back to just reflecting the ain/aout parameters
250      FE->mUsrData = NULL;
251    } else {
252      clang::RecordDecl *RD =
253          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
254                                    Ctx.getTranslationUnitDecl(),
255                                    clang::SourceLocation(),
256                                    clang::SourceLocation(),
257                                    &Ctx.Idents.get(Id));
258
259      clang::FieldDecl *FD =
260          clang::FieldDecl::Create(Ctx,
261                                   RD,
262                                   clang::SourceLocation(),
263                                   clang::SourceLocation(),
264                                   PVD->getIdentifier(),
265                                   QT->getPointeeType(),
266                                   NULL,
267                                   /* BitWidth = */ NULL,
268                                   /* Mutable = */ false,
269                                   /* HasInit = */ false);
270      RD->addDecl(FD);
271      RD->completeDefinition();
272
273      // Create an export type iff we have a valid usrData type
274      clang::QualType T = Ctx.getTagDeclType(RD);
275      slangAssert(!T.isNull());
276
277      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
278
279      if (ET == NULL) {
280        fprintf(stderr, "Failed to export the function %s. There's at least "
281                        "one parameter whose type is not supported by the "
282                        "reflection\n", FE->getName().c_str());
283        return NULL;
284      }
285
286      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
287                  "Parameter packet must be a record");
288
289      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
290    }
291  }
292
293  if (FE->mIn) {
294    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
295    FE->mInType = RSExportType::Create(Context, T);
296  }
297
298  if (FE->mOut) {
299    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
300    FE->mOutType = RSExportType::Create(Context, T);
301  }
302
303  return FE;
304}
305
306bool RSExportForEach::isGraphicsRootRSFunc(int targetAPI,
307                                           const clang::FunctionDecl *FD) {
308  if (!isRootRSFunc(FD)) {
309    return false;
310  }
311
312  if (FD->getNumParams() == 0) {
313    // Graphics root function
314    return true;
315  }
316
317  // Check for legacy graphics root function (with single parameter).
318  if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
319    const clang::QualType &IntType = FD->getASTContext().IntTy;
320    if (FD->getResultType().getCanonicalType() == IntType) {
321      return true;
322    }
323  }
324
325  return false;
326}
327
328bool RSExportForEach::isRSForEachFunc(int targetAPI,
329    const clang::FunctionDecl *FD) {
330  if (isGraphicsRootRSFunc(targetAPI, FD)) {
331    return false;
332  }
333
334  // Check if first parameter is a pointer (which is required for ForEach).
335  unsigned int numParams = FD->getNumParams();
336
337  if (numParams > 0) {
338    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
339    clang::QualType QT = PVD->getType().getCanonicalType();
340
341    if (QT->isPointerType()) {
342      return true;
343    }
344
345    // Any non-graphics root() is automatically a ForEach candidate.
346    // At this point, however, we know that it is not going to be a valid
347    // compute root() function (due to not having a pointer parameter). We
348    // still want to return true here, so that we can issue appropriate
349    // diagnostics.
350    if (isRootRSFunc(FD)) {
351      return true;
352    }
353  }
354
355  return false;
356}
357
358bool
359RSExportForEach::validateSpecialFuncDecl(int targetAPI,
360                                         clang::DiagnosticsEngine *DiagEngine,
361                                         clang::FunctionDecl const *FD) {
362  slangAssert(DiagEngine && FD);
363  bool valid = true;
364  const clang::ASTContext &C = FD->getASTContext();
365  const clang::QualType &IntType = FD->getASTContext().IntTy;
366
367  if (isGraphicsRootRSFunc(targetAPI, FD)) {
368    if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
369      // Legacy graphics root function
370      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
371      clang::QualType QT = PVD->getType().getCanonicalType();
372      if (QT != IntType) {
373        DiagEngine->Report(
374          clang::FullSourceLoc(PVD->getLocation(),
375                               DiagEngine->getSourceManager()),
376          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
377                                      "invalid parameter type for legacy "
378                                      "graphics root() function: %0"))
379          << PVD->getType();
380        valid = false;
381      }
382    }
383
384    // Graphics root function, so verify that it returns an int
385    if (FD->getResultType().getCanonicalType() != IntType) {
386      DiagEngine->Report(
387        clang::FullSourceLoc(FD->getLocation(),
388                             DiagEngine->getSourceManager()),
389        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
390                                    "root() is required to return "
391                                    "an int for graphics usage"));
392      valid = false;
393    }
394  } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
395    if (FD->getNumParams() != 0) {
396      DiagEngine->Report(
397          clang::FullSourceLoc(FD->getLocation(),
398                               DiagEngine->getSourceManager()),
399          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
400                                      "%0(void) is required to have no "
401                                      "parameters")) << FD->getName();
402      valid = false;
403    }
404
405    if (FD->getResultType().getCanonicalType() != C.VoidTy) {
406      DiagEngine->Report(
407          clang::FullSourceLoc(FD->getLocation(),
408                               DiagEngine->getSourceManager()),
409          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
410                                      "%0(void) is required to have a void "
411                                      "return type")) << FD->getName();
412      valid = false;
413    }
414  } else {
415    slangAssert(false && "must be called on root, init or .rs.dtor function!");
416  }
417
418  return valid;
419}
420
421}  // namespace slang
422