slang_rs_export_foreach.cpp revision b69aa6557572c9ca91c46add3016962af0c993e7
1/*
2 * Copyright 2011, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Decl.h"
23#include "clang/AST/TypeLoc.h"
24
25#include "llvm/DerivedTypes.h"
26#include "llvm/Target/TargetData.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31
32namespace slang {
33
34namespace {
35
36static void ReportNameError(clang::Diagnostic *Diags,
37                            const clang::ParmVarDecl *PVD) {
38  slangAssert(Diags && PVD);
39  const clang::SourceManager &SM = Diags->getSourceManager();
40
41  Diags->Report(clang::FullSourceLoc(PVD->getLocation(), SM),
42                Diags->getCustomDiagID(clang::Diagnostic::Error,
43                "Duplicate parameter entry (by position/name): '%0'"))
44       << PVD->getName();
45  return;
46}
47
48}  // namespace
49
50// This function takes care of additional validation and construction of
51// parameters related to forEach_* reflection.
52bool RSExportForEach::validateAndConstructParams(
53    RSContext *Context, const clang::FunctionDecl *FD) {
54  slangAssert(Context && FD);
55  bool valid = true;
56  clang::ASTContext &C = Context->getASTContext();
57  clang::Diagnostic *Diags = Context->getDiagnostics();
58
59  if (!isRootRSFunc(FD)) {
60    slangAssert(false && "must be called on compute root function!");
61  }
62
63  numParams = FD->getNumParams();
64  slangAssert(numParams > 0);
65
66  // Compute root functions are required to return a void type for now
67  if (FD->getResultType().getCanonicalType() != C.VoidTy) {
68    Diags->Report(
69        clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
70        Diags->getCustomDiagID(clang::Diagnostic::Error,
71                               "compute root() is required to return a "
72                               "void type"));
73    valid = false;
74  }
75
76  // Validate remaining parameter types
77  // TODO(all): Add support for LOD/face when we have them
78
79  for (size_t i = 0; i < numParams; i++) {
80    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
81    clang::QualType QT = PVD->getType().getCanonicalType();
82    llvm::StringRef ParamName = PVD->getName();
83
84    if (QT->isPointerType()) {
85      if (QT->getPointeeType().isConstQualified()) {
86        // const T1 *in
87        // const T3 *usrData
88        if (ParamName.equals("in")) {
89          if (mIn) {
90            ReportNameError(Diags, PVD);
91            valid = false;
92          } else {
93            mIn = PVD;
94          }
95        } else if (ParamName.equals("usrData")) {
96          if (mUsrData) {
97            ReportNameError(Diags, PVD);
98            valid = false;
99          } else {
100            mUsrData = PVD;
101          }
102        } else {
103          // Issue warning about positional parameter usage
104          if (!mIn) {
105            mIn = PVD;
106          } else if (!mUsrData) {
107            mUsrData = PVD;
108          } else {
109            Diags->Report(
110                clang::FullSourceLoc(PVD->getLocation(),
111                                     Diags->getSourceManager()),
112                Diags->getCustomDiagID(clang::Diagnostic::Error,
113                                       "Unexpected root() parameter '%0' "
114                                       "of type '%1'"))
115                << PVD->getName() << PVD->getType().getAsString();
116            valid = false;
117          }
118        }
119      } else {
120        // T2 *out
121        if (ParamName.equals("out")) {
122          if (mOut) {
123            ReportNameError(Diags, PVD);
124            valid = false;
125          } else {
126            mOut = PVD;
127          }
128        } else {
129          if (!mOut) {
130            mOut = PVD;
131          } else {
132            Diags->Report(
133                clang::FullSourceLoc(PVD->getLocation(),
134                                     Diags->getSourceManager()),
135                Diags->getCustomDiagID(clang::Diagnostic::Error,
136                                       "Unexpected root() parameter '%0' "
137                                       "of type '%1'"))
138                << PVD->getName() << PVD->getType().getAsString();
139            valid = false;
140          }
141        }
142      }
143    } else if (QT.getUnqualifiedType() == C.UnsignedIntTy) {
144      if (ParamName.equals("x")) {
145        if (mX) {
146          ReportNameError(Diags, PVD);
147          valid = false;
148        } else {
149          mX = PVD;
150        }
151      } else if (ParamName.equals("y")) {
152        if (mY) {
153          ReportNameError(Diags, PVD);
154          valid = false;
155        } else {
156          mY = PVD;
157        }
158      } else if (ParamName.equals("z")) {
159        if (mZ) {
160          ReportNameError(Diags, PVD);
161          valid = false;
162        } else {
163          mZ = PVD;
164        }
165      } else if (ParamName.equals("ar")) {
166        if (mAr) {
167          ReportNameError(Diags, PVD);
168          valid = false;
169        } else {
170          mAr = PVD;
171        }
172      } else {
173        if (!mX) {
174          mX = PVD;
175        } else if (!mY) {
176          mY = PVD;
177        } else if (!mZ) {
178          mZ = PVD;
179        } else if (!mAr) {
180          mAr = PVD;
181        } else {
182          Diags->Report(
183              clang::FullSourceLoc(PVD->getLocation(),
184                                   Diags->getSourceManager()),
185              Diags->getCustomDiagID(clang::Diagnostic::Error,
186                                     "Unexpected root() parameter '%0' "
187                                     "of type '%1'"))
188              << PVD->getName() << PVD->getType().getAsString();
189          valid = false;
190        }
191      }
192    } else {
193      Diags->Report(
194          clang::FullSourceLoc(
195              PVD->getTypeSourceInfo()->getTypeLoc().getBeginLoc(),
196              Diags->getSourceManager()),
197          Diags->getCustomDiagID(clang::Diagnostic::Error,
198              "Unexpected root() parameter type '%0'"))
199          << PVD->getType().getAsString();
200      valid = false;
201    }
202  }
203
204  if (!mIn && !mOut) {
205    Diags->Report(
206        clang::FullSourceLoc(FD->getLocation(),
207                             Diags->getSourceManager()),
208        Diags->getCustomDiagID(clang::Diagnostic::Error,
209                               "Compute root() must have at least one "
210                               "parameter for in or out"));
211    valid = false;
212  }
213
214  return valid;
215}
216
217RSExportForEach *RSExportForEach::Create(RSContext *Context,
218                                         const clang::FunctionDecl *FD) {
219  slangAssert(Context && FD);
220  llvm::StringRef Name = FD->getName();
221  RSExportForEach *FE;
222
223  slangAssert(!Name.empty() && "Function must have a name");
224
225  FE = new RSExportForEach(Context, Name, FD);
226
227  if (!FE->validateAndConstructParams(Context, FD)) {
228    return NULL;
229  }
230
231  clang::ASTContext &Ctx = Context->getASTContext();
232
233  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
234  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
235
236  // Extract the usrData parameter (if we have one)
237  if (FE->mUsrData) {
238    const clang::ParmVarDecl *PVD = FE->mUsrData;
239    clang::QualType QT = PVD->getType().getCanonicalType();
240    slangAssert(QT->isPointerType() &&
241                QT->getPointeeType().isConstQualified());
242
243    const clang::ASTContext &C = Context->getASTContext();
244    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
245        C.VoidTy) {
246      // In the case of using const void*, we can't reflect an appopriate
247      // Java type, so we fall back to just reflecting the ain/aout parameters
248      FE->mUsrData = NULL;
249    } else {
250      clang::RecordDecl *RD =
251          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
252                                    Ctx.getTranslationUnitDecl(),
253                                    clang::SourceLocation(),
254                                    clang::SourceLocation(),
255                                    &Ctx.Idents.get(Id));
256
257      llvm::StringRef ParamName = PVD->getName();
258      clang::FieldDecl *FD =
259          clang::FieldDecl::Create(Ctx,
260                                   RD,
261                                   clang::SourceLocation(),
262                                   clang::SourceLocation(),
263                                   PVD->getIdentifier(),
264                                   QT->getPointeeType(),
265                                   NULL,
266                                   /* BitWidth = */ NULL,
267                                   /* Mutable = */ false,
268                                   /* HasInit = */ false);
269      RD->addDecl(FD);
270      RD->completeDefinition();
271
272      // Create an export type iff we have a valid usrData type
273      clang::QualType T = Ctx.getTagDeclType(RD);
274      slangAssert(!T.isNull());
275
276      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
277
278      if (ET == NULL) {
279        fprintf(stderr, "Failed to export the function %s. There's at least "
280                        "one parameter whose type is not supported by the "
281                        "reflection\n", FE->getName().c_str());
282        return NULL;
283      }
284
285      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
286                  "Parameter packet must be a record");
287
288      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
289    }
290  }
291
292  if (FE->mIn) {
293    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
294    FE->mInType = RSExportType::Create(Context, T);
295  }
296
297  if (FE->mOut) {
298    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
299    FE->mOutType = RSExportType::Create(Context, T);
300  }
301
302  return FE;
303}
304
305bool RSExportForEach::isRSForEachFunc(const clang::FunctionDecl *FD) {
306  // We currently support only compute root() being exported via forEach
307  if (!isRootRSFunc(FD)) {
308    return false;
309  }
310
311  if (FD->getNumParams() == 0) {
312    // Graphics compute function
313    return false;
314  }
315  return true;
316}
317
318bool RSExportForEach::validateSpecialFuncDecl(clang::Diagnostic *Diags,
319                                              const clang::FunctionDecl *FD) {
320  slangAssert(Diags && FD);
321  bool valid = true;
322  const clang::ASTContext &C = FD->getASTContext();
323
324  if (isRootRSFunc(FD)) {
325    unsigned int numParams = FD->getNumParams();
326    if (numParams == 0) {
327      // Graphics root function, so verify that it returns an int
328      if (FD->getResultType().getCanonicalType() != C.IntTy) {
329        Diags->Report(
330            clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
331            Diags->getCustomDiagID(clang::Diagnostic::Error,
332                                   "root(void) is required to return "
333                                   "an int for graphics usage"));
334        valid = false;
335      }
336    } else {
337      slangAssert(false &&
338          "Should not call validateSpecialFuncDecl() on compute root()");
339    }
340  } else if (isInitRSFunc(FD)) {
341    if (FD->getNumParams() != 0) {
342      Diags->Report(
343          clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
344          Diags->getCustomDiagID(clang::Diagnostic::Error,
345                                 "init(void) is required to have no "
346                                 "parameters"));
347      valid = false;
348    }
349
350    if (FD->getResultType().getCanonicalType() != C.VoidTy) {
351      Diags->Report(
352          clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
353          Diags->getCustomDiagID(clang::Diagnostic::Error,
354                                 "init(void) is required to have a void "
355                                 "return type"));
356      valid = false;
357    }
358  } else {
359    slangAssert(false && "must be called on init or root function!");
360  }
361
362  return valid;
363}
364
365}  // namespace slang
366