slang_rs_export_foreach.cpp revision b5a89fbfcba6d8817c1c3700ed78bd6482cf1a5d
1/*
2 * Copyright 2011, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Decl.h"
23#include "clang/AST/TypeLoc.h"
24
25#include "llvm/DerivedTypes.h"
26#include "llvm/Target/TargetData.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31
32namespace slang {
33
34namespace {
35
36static void ReportNameError(clang::Diagnostic *Diags,
37                            const clang::ParmVarDecl *PVD) {
38  slangAssert(Diags && PVD);
39  const clang::SourceManager &SM = Diags->getSourceManager();
40
41  Diags->Report(clang::FullSourceLoc(PVD->getLocation(), SM),
42                Diags->getCustomDiagID(clang::Diagnostic::Error,
43                "Duplicate parameter entry (by position/name): '%0'"))
44       << PVD->getName();
45  return;
46}
47
48}  // namespace
49
50// This function takes care of additional validation and construction of
51// parameters related to forEach_* reflection.
52bool RSExportForEach::validateAndConstructParams(
53    RSContext *Context, const clang::FunctionDecl *FD) {
54  slangAssert(Context && FD);
55  bool valid = true;
56  clang::ASTContext &C = Context->getASTContext();
57  clang::Diagnostic *Diags = Context->getDiagnostics();
58
59  if (!isRootRSFunc(FD)) {
60    slangAssert(false && "must be called on compute root function!");
61  }
62
63  numParams = FD->getNumParams();
64  slangAssert(numParams > 0);
65
66  // Compute root functions are required to return a void type for now
67  if (FD->getResultType().getCanonicalType() != C.VoidTy) {
68    Diags->Report(
69        clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
70        Diags->getCustomDiagID(clang::Diagnostic::Error,
71                               "compute root() is required to return a "
72                               "void type"));
73    valid = false;
74  }
75
76  // Validate remaining parameter types
77  // TODO(all): Add support for LOD/face when we have them
78
79  for (size_t i = 0; i < numParams; i++) {
80    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
81    clang::QualType QT = PVD->getType().getCanonicalType();
82    llvm::StringRef ParamName = PVD->getName();
83
84    if (QT->isPointerType()) {
85      if (QT->getPointeeType().isConstQualified()) {
86        // const T1 *in
87        // const T3 *usrData
88        if (ParamName.equals("in")) {
89          if (mIn) {
90            ReportNameError(Diags, PVD);
91            valid = false;
92          } else {
93            mIn = PVD;
94          }
95        } else if (ParamName.equals("usrData")) {
96          if (mUsrData) {
97            ReportNameError(Diags, PVD);
98            valid = false;
99          } else {
100            mUsrData = PVD;
101          }
102        } else {
103          // Issue warning about positional parameter usage
104          if (!mIn) {
105            mIn = PVD;
106          } else if (!mUsrData) {
107            mUsrData = PVD;
108          } else {
109            Diags->Report(
110                clang::FullSourceLoc(PVD->getLocation(),
111                                     Diags->getSourceManager()),
112                Diags->getCustomDiagID(clang::Diagnostic::Error,
113                                       "Unexpected root() parameter '%0' "
114                                       "of type '%1'"))
115                << PVD->getName() << PVD->getType().getAsString();
116            valid = false;
117          }
118        }
119      } else {
120        // T2 *out
121        if (ParamName.equals("out")) {
122          if (mOut) {
123            ReportNameError(Diags, PVD);
124            valid = false;
125          } else {
126            mOut = PVD;
127          }
128        } else {
129          if (!mOut) {
130            mOut = PVD;
131          } else {
132            Diags->Report(
133                clang::FullSourceLoc(PVD->getLocation(),
134                                     Diags->getSourceManager()),
135                Diags->getCustomDiagID(clang::Diagnostic::Error,
136                                       "Unexpected root() parameter '%0' "
137                                       "of type '%1'"))
138                << PVD->getName() << PVD->getType().getAsString();
139            valid = false;
140          }
141        }
142      }
143    } else if (QT.getUnqualifiedType() == C.UnsignedIntTy) {
144      if (ParamName.equals("x")) {
145        if (mX) {
146          ReportNameError(Diags, PVD);
147          valid = false;
148        } else {
149          mX = PVD;
150        }
151      } else if (ParamName.equals("y")) {
152        if (mY) {
153          ReportNameError(Diags, PVD);
154          valid = false;
155        } else {
156          mY = PVD;
157        }
158      } else if (ParamName.equals("z")) {
159        if (mZ) {
160          ReportNameError(Diags, PVD);
161          valid = false;
162        } else {
163          mZ = PVD;
164        }
165      } else if (ParamName.equals("ar")) {
166        if (mAr) {
167          ReportNameError(Diags, PVD);
168          valid = false;
169        } else {
170          mAr = PVD;
171        }
172      } else {
173        if (!mX) {
174          mX = PVD;
175        } else if (!mY) {
176          mY = PVD;
177        } else if (!mZ) {
178          mZ = PVD;
179        } else if (!mAr) {
180          mAr = PVD;
181        } else {
182          Diags->Report(
183              clang::FullSourceLoc(PVD->getLocation(),
184                                   Diags->getSourceManager()),
185              Diags->getCustomDiagID(clang::Diagnostic::Error,
186                                     "Unexpected root() parameter '%0' "
187                                     "of type '%1'"))
188              << PVD->getName() << PVD->getType().getAsString();
189          valid = false;
190        }
191      }
192    } else {
193      Diags->Report(
194          clang::FullSourceLoc(
195              PVD->getTypeSourceInfo()->getTypeLoc().getBeginLoc(),
196              Diags->getSourceManager()),
197          Diags->getCustomDiagID(clang::Diagnostic::Error,
198              "Unexpected root() parameter type '%0'"))
199          << PVD->getType().getAsString();
200      valid = false;
201    }
202  }
203
204  if (!mIn && !mOut) {
205    Diags->Report(
206        clang::FullSourceLoc(FD->getLocation(),
207                             Diags->getSourceManager()),
208        Diags->getCustomDiagID(clang::Diagnostic::Error,
209                               "Compute root() must have at least one "
210                               "parameter for in or out"));
211    valid = false;
212  }
213
214  return valid;
215}
216
217RSExportForEach *RSExportForEach::Create(RSContext *Context,
218                                         const clang::FunctionDecl *FD) {
219  slangAssert(Context && FD);
220  llvm::StringRef Name = FD->getName();
221  RSExportForEach *FE;
222
223  slangAssert(!Name.empty() && "Function must have a name");
224
225  FE = new RSExportForEach(Context, Name, FD);
226
227  if (!FE->validateAndConstructParams(Context, FD)) {
228    delete FE;
229    return NULL;
230  }
231
232  clang::ASTContext &Ctx = Context->getASTContext();
233
234  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
235  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
236
237  // Extract the usrData parameter (if we have one)
238  if (FE->mUsrData) {
239    const clang::ParmVarDecl *PVD = FE->mUsrData;
240    clang::QualType QT = PVD->getType().getCanonicalType();
241    slangAssert(QT->isPointerType() &&
242                QT->getPointeeType().isConstQualified());
243
244    const clang::ASTContext &C = Context->getASTContext();
245    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
246        C.VoidTy) {
247      // In the case of using const void*, we can't reflect an appopriate
248      // Java type, so we fall back to just reflecting the ain/aout parameters
249      FE->mUsrData = NULL;
250    } else {
251      clang::RecordDecl *RD =
252          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
253                                    Ctx.getTranslationUnitDecl(),
254                                    clang::SourceLocation(),
255                                    clang::SourceLocation(),
256                                    &Ctx.Idents.get(Id));
257
258      llvm::StringRef ParamName = PVD->getName();
259      clang::FieldDecl *FD =
260          clang::FieldDecl::Create(Ctx,
261                                   RD,
262                                   clang::SourceLocation(),
263                                   clang::SourceLocation(),
264                                   PVD->getIdentifier(),
265                                   QT->getPointeeType(),
266                                   NULL,
267                                   /* BitWidth = */NULL,
268                                   /* Mutable = */false);
269      RD->addDecl(FD);
270      RD->completeDefinition();
271
272      // Create an export type iff we have a valid usrData type
273      clang::QualType T = Ctx.getTagDeclType(RD);
274      slangAssert(!T.isNull());
275
276      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
277
278      if (ET == NULL) {
279        fprintf(stderr, "Failed to export the function %s. There's at least "
280                        "one parameter whose type is not supported by the "
281                        "reflection\n", FE->getName().c_str());
282        delete FE;
283        return NULL;
284      }
285
286      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
287                  "Parameter packet must be a record");
288
289      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
290    }
291  }
292
293  if (FE->mIn) {
294    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
295    FE->mInType = RSExportType::Create(Context, T);
296  }
297
298  if (FE->mOut) {
299    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
300    FE->mOutType = RSExportType::Create(Context, T);
301  }
302
303  return FE;
304}
305
306bool RSExportForEach::isRSForEachFunc(const clang::FunctionDecl *FD) {
307  // We currently support only compute root() being exported via forEach
308  if (!isRootRSFunc(FD)) {
309    return false;
310  }
311
312  if (FD->getNumParams() == 0) {
313    // Graphics compute function
314    return false;
315  }
316  return true;
317}
318
319bool RSExportForEach::validateSpecialFuncDecl(clang::Diagnostic *Diags,
320                                              const clang::FunctionDecl *FD) {
321  slangAssert(Diags && FD);
322  bool valid = true;
323  const clang::ASTContext &C = FD->getASTContext();
324
325  if (isRootRSFunc(FD)) {
326    unsigned int numParams = FD->getNumParams();
327    if (numParams == 0) {
328      // Graphics root function, so verify that it returns an int
329      if (FD->getResultType().getCanonicalType() != C.IntTy) {
330        Diags->Report(
331            clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
332            Diags->getCustomDiagID(clang::Diagnostic::Error,
333                                   "root(void) is required to return "
334                                   "an int for graphics usage"));
335        valid = false;
336      }
337    } else {
338      slangAssert(false &&
339          "Should not call validateSpecialFuncDecl() on compute root()");
340    }
341  } else if (isInitRSFunc(FD)) {
342    if (FD->getNumParams() != 0) {
343      Diags->Report(
344          clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
345          Diags->getCustomDiagID(clang::Diagnostic::Error,
346                                 "init(void) is required to have no "
347                                 "parameters"));
348      valid = false;
349    }
350
351    if (FD->getResultType().getCanonicalType() != C.VoidTy) {
352      Diags->Report(
353          clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
354          Diags->getCustomDiagID(clang::Diagnostic::Error,
355                                 "init(void) is required to have a void "
356                                 "return type"));
357      valid = false;
358    }
359  } else {
360    slangAssert(false && "must be called on init or root function!");
361  }
362
363  return valid;
364}
365
366}  // namespace slang
367