slang_rs_export_foreach.cpp revision 9207a2e495c8363606861e4f034504ec5c153dab
1/*
2 * Copyright 2011, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Decl.h"
23#include "clang/AST/TypeLoc.h"
24
25#include "llvm/DerivedTypes.h"
26#include "llvm/Target/TargetData.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31#include "slang_version.h"
32
33namespace slang {
34
35namespace {
36
37static void ReportNameError(clang::DiagnosticsEngine *DiagEngine,
38                            clang::ParmVarDecl const *PVD) {
39  slangAssert(DiagEngine && PVD);
40  const clang::SourceManager &SM = DiagEngine->getSourceManager();
41
42  DiagEngine->Report(
43    clang::FullSourceLoc(PVD->getLocation(), SM),
44    DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
45                                "Duplicate parameter entry "
46                                "(by position/name): '%0'"))
47    << PVD->getName();
48  return;
49}
50
51}  // namespace
52
53// This function takes care of additional validation and construction of
54// parameters related to forEach_* reflection.
55bool RSExportForEach::validateAndConstructParams(
56    RSContext *Context, const clang::FunctionDecl *FD) {
57  slangAssert(Context && FD);
58  bool valid = true;
59  clang::ASTContext &C = Context->getASTContext();
60  clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
61
62  if (!isRootRSFunc(FD)) {
63    slangAssert(false && "must be called on compute root function!");
64  }
65
66  numParams = FD->getNumParams();
67  slangAssert(numParams > 0);
68
69  // Compute root functions are required to return a void type for now
70  if (FD->getResultType().getCanonicalType() != C.VoidTy) {
71    DiagEngine->Report(
72      clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
73      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
74                                  "compute root() is required to return a "
75                                  "void type"));
76    valid = false;
77  }
78
79  // Validate remaining parameter types
80  // TODO(all): Add support for LOD/face when we have them
81
82  size_t i = 0;
83  const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
84  clang::QualType QT = PVD->getType().getCanonicalType();
85
86  // Check for const T1 *in
87  if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
88    mIn = PVD;
89    i++;  // advance parameter pointer
90  }
91
92  // Check for T2 *out
93  if (i < numParams) {
94    PVD = FD->getParamDecl(i);
95    QT = PVD->getType().getCanonicalType();
96    if (QT->isPointerType() && !QT->getPointeeType().isConstQualified()) {
97      mOut = PVD;
98      i++;  // advance parameter pointer
99    }
100  }
101
102  if (!mIn && !mOut) {
103    DiagEngine->Report(
104      clang::FullSourceLoc(FD->getLocation(),
105                           DiagEngine->getSourceManager()),
106      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
107                                  "Compute root() must have at least one "
108                                  "parameter for in or out"));
109    valid = false;
110  }
111
112  // Check for T3 *usrData
113  if (i < numParams) {
114    PVD = FD->getParamDecl(i);
115    QT = PVD->getType().getCanonicalType();
116    if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
117      mUsrData = PVD;
118      i++;  // advance parameter pointer
119    }
120  }
121
122  while (i < numParams) {
123    PVD = FD->getParamDecl(i);
124    QT = PVD->getType().getCanonicalType();
125
126    if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
127      DiagEngine->Report(
128        clang::FullSourceLoc(PVD->getLocation(),
129                             DiagEngine->getSourceManager()),
130        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
131                                    "Unexpected root() parameter '%0' "
132                                    "of type '%1'"))
133        << PVD->getName() << PVD->getType().getAsString();
134      valid = false;
135    } else {
136      llvm::StringRef ParamName = PVD->getName();
137      if (ParamName.equals("x")) {
138        if (mX) {
139          ReportNameError(DiagEngine, PVD);
140          valid = false;
141        } else if (mY) {
142          // Can't go back to X after skipping Y
143          ReportNameError(DiagEngine, PVD);
144          valid = false;
145        } else {
146          mX = PVD;
147        }
148      } else if (ParamName.equals("y")) {
149        if (mY) {
150          ReportNameError(DiagEngine, PVD);
151          valid = false;
152        } else {
153          mY = PVD;
154        }
155      } else {
156        if (!mX && !mY) {
157          mX = PVD;
158        } else if (!mY) {
159          mY = PVD;
160        } else {
161          DiagEngine->Report(
162            clang::FullSourceLoc(PVD->getLocation(),
163                                 DiagEngine->getSourceManager()),
164            DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
165                                        "Unexpected root() parameter '%0' "
166                                        "of type '%1'"))
167            << PVD->getName() << PVD->getType().getAsString();
168          valid = false;
169        }
170      }
171    }
172
173    i++;
174  }
175
176  mMetadataEncoding = 0;
177  if (valid) {
178    // Set up the bitwise metadata encoding for runtime argument passing.
179    mMetadataEncoding |= (mIn ?       0x01 : 0);
180    mMetadataEncoding |= (mOut ?      0x02 : 0);
181    mMetadataEncoding |= (mUsrData ?  0x04 : 0);
182    mMetadataEncoding |= (mX ?        0x08 : 0);
183    mMetadataEncoding |= (mY ?        0x10 : 0);
184  }
185
186  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
187    // APIs before ICS cannot skip between parameters. It is ok, however, for
188    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
189    if (mMetadataEncoding != 0x1f &&  // In, Out, UsrData, X, Y
190        mMetadataEncoding != 0x0f &&  // In, Out, UsrData, X
191        mMetadataEncoding != 0x07 &&  // In, Out, UsrData
192        mMetadataEncoding != 0x03 &&  // In, Out
193        mMetadataEncoding != 0x01) {  // In
194      DiagEngine->Report(
195        clang::FullSourceLoc(FD->getLocation(),
196                             DiagEngine->getSourceManager()),
197        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
198                                    "Compute root() targeting SDK levels "
199                                    "%0-%1 may not skip parameters"))
200        << SLANG_MINIMUM_TARGET_API << (SLANG_ICS_TARGET_API-1);
201      valid = false;
202    }
203  }
204
205
206  return valid;
207}
208
209RSExportForEach *RSExportForEach::Create(RSContext *Context,
210                                         const clang::FunctionDecl *FD) {
211  slangAssert(Context && FD);
212  llvm::StringRef Name = FD->getName();
213  RSExportForEach *FE;
214
215  slangAssert(!Name.empty() && "Function must have a name");
216
217  FE = new RSExportForEach(Context, Name, FD);
218
219  if (!FE->validateAndConstructParams(Context, FD)) {
220    return NULL;
221  }
222
223  clang::ASTContext &Ctx = Context->getASTContext();
224
225  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
226  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
227
228  // Extract the usrData parameter (if we have one)
229  if (FE->mUsrData) {
230    const clang::ParmVarDecl *PVD = FE->mUsrData;
231    clang::QualType QT = PVD->getType().getCanonicalType();
232    slangAssert(QT->isPointerType() &&
233                QT->getPointeeType().isConstQualified());
234
235    const clang::ASTContext &C = Context->getASTContext();
236    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
237        C.VoidTy) {
238      // In the case of using const void*, we can't reflect an appopriate
239      // Java type, so we fall back to just reflecting the ain/aout parameters
240      FE->mUsrData = NULL;
241    } else {
242      clang::RecordDecl *RD =
243          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
244                                    Ctx.getTranslationUnitDecl(),
245                                    clang::SourceLocation(),
246                                    clang::SourceLocation(),
247                                    &Ctx.Idents.get(Id));
248
249      llvm::StringRef ParamName = PVD->getName();
250      clang::FieldDecl *FD =
251          clang::FieldDecl::Create(Ctx,
252                                   RD,
253                                   clang::SourceLocation(),
254                                   clang::SourceLocation(),
255                                   PVD->getIdentifier(),
256                                   QT->getPointeeType(),
257                                   NULL,
258                                   /* BitWidth = */ NULL,
259                                   /* Mutable = */ false,
260                                   /* HasInit = */ false);
261      RD->addDecl(FD);
262      RD->completeDefinition();
263
264      // Create an export type iff we have a valid usrData type
265      clang::QualType T = Ctx.getTagDeclType(RD);
266      slangAssert(!T.isNull());
267
268      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
269
270      if (ET == NULL) {
271        fprintf(stderr, "Failed to export the function %s. There's at least "
272                        "one parameter whose type is not supported by the "
273                        "reflection\n", FE->getName().c_str());
274        return NULL;
275      }
276
277      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
278                  "Parameter packet must be a record");
279
280      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
281    }
282  }
283
284  if (FE->mIn) {
285    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
286    FE->mInType = RSExportType::Create(Context, T);
287  }
288
289  if (FE->mOut) {
290    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
291    FE->mOutType = RSExportType::Create(Context, T);
292  }
293
294  return FE;
295}
296
297bool RSExportForEach::isRSForEachFunc(const clang::FunctionDecl *FD) {
298  // We currently support only compute root() being exported via forEach
299  if (!isRootRSFunc(FD)) {
300    return false;
301  }
302
303  if (FD->getNumParams() == 0) {
304    // Graphics compute function
305    return false;
306  }
307  return true;
308}
309
310bool
311RSExportForEach::validateSpecialFuncDecl(clang::DiagnosticsEngine *DiagEngine,
312                                         clang::FunctionDecl const *FD) {
313  slangAssert(DiagEngine && FD);
314  bool valid = true;
315  const clang::ASTContext &C = FD->getASTContext();
316
317  if (isRootRSFunc(FD)) {
318    unsigned int numParams = FD->getNumParams();
319    if (numParams == 0) {
320      // Graphics root function, so verify that it returns an int
321      if (FD->getResultType().getCanonicalType() != C.IntTy) {
322        DiagEngine->Report(
323          clang::FullSourceLoc(FD->getLocation(),
324                               DiagEngine->getSourceManager()),
325          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
326                                      "root(void) is required to return "
327                                      "an int for graphics usage"));
328        valid = false;
329      }
330    } else {
331      slangAssert(false &&
332          "Should not call validateSpecialFuncDecl() on compute root()");
333    }
334  } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
335    if (FD->getNumParams() != 0) {
336      DiagEngine->Report(
337          clang::FullSourceLoc(FD->getLocation(),
338                               DiagEngine->getSourceManager()),
339          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
340                                      "%0(void) is required to have no "
341                                      "parameters")) << FD->getName();
342      valid = false;
343    }
344
345    if (FD->getResultType().getCanonicalType() != C.VoidTy) {
346      DiagEngine->Report(
347          clang::FullSourceLoc(FD->getLocation(),
348                               DiagEngine->getSourceManager()),
349          DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
350                                      "%0(void) is required to have a void "
351                                      "return type")) << FD->getName();
352      valid = false;
353    }
354  } else {
355    slangAssert(false && "must be called on root, init or .rs.dtor function!");
356  }
357
358  return valid;
359}
360
361}  // namespace slang
362