slang_rs_export_foreach.cpp revision 688e64b2d56e4218c680b9d6523c5de672f55757
1/*
2 * Copyright 2011, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Decl.h"
23#include "clang/AST/TypeLoc.h"
24
25#include "llvm/DerivedTypes.h"
26#include "llvm/Target/TargetData.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31
32namespace slang {
33
34namespace {
35
36static void ReportNameError(clang::Diagnostic *Diags,
37                            const clang::ParmVarDecl *PVD) {
38  slangAssert(Diags && PVD);
39  const clang::SourceManager &SM = Diags->getSourceManager();
40
41  Diags->Report(clang::FullSourceLoc(PVD->getLocation(), SM),
42                Diags->getCustomDiagID(clang::Diagnostic::Error,
43                "Duplicate parameter entry (by position/name): '%0'"))
44       << PVD->getName();
45  return;
46}
47
48}  // namespace
49
50// This function takes care of additional validation and construction of
51// parameters related to forEach_* reflection.
52bool RSExportForEach::validateAndConstructParams(
53    RSContext *Context, const clang::FunctionDecl *FD) {
54  slangAssert(Context && FD);
55  bool valid = true;
56  clang::ASTContext &C = Context->getASTContext();
57  clang::Diagnostic *Diags = Context->getDiagnostics();
58
59  if (!isRootRSFunc(FD)) {
60    slangAssert(false && "must be called on compute root function!");
61  }
62
63  numParams = FD->getNumParams();
64  slangAssert(numParams > 0);
65
66  // Compute root functions are required to return a void type for now
67  if (FD->getResultType().getCanonicalType() != C.VoidTy) {
68    Diags->Report(
69        clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
70        Diags->getCustomDiagID(clang::Diagnostic::Error,
71                               "compute root() is required to return a "
72                               "void type"));
73    valid = false;
74  }
75
76  // Validate remaining parameter types
77  // TODO(all): Add support for LOD/face when we have them
78
79  size_t i = 0;
80  const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
81  clang::QualType QT = PVD->getType().getCanonicalType();
82
83  // Check for const T1 *in
84  if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
85    mIn = PVD;
86    i++;  // advance parameter pointer
87  }
88
89  // Check for T2 *out
90  if (i < numParams) {
91    PVD = FD->getParamDecl(i);
92    QT = PVD->getType().getCanonicalType();
93    if (QT->isPointerType() && !QT->getPointeeType().isConstQualified()) {
94      mOut = PVD;
95      i++;  // advance parameter pointer
96    }
97  }
98
99  if (!mIn && !mOut) {
100    Diags->Report(
101        clang::FullSourceLoc(FD->getLocation(),
102                             Diags->getSourceManager()),
103        Diags->getCustomDiagID(clang::Diagnostic::Error,
104                               "Compute root() must have at least one "
105                               "parameter for in or out"));
106    valid = false;
107  }
108
109  // Check for T3 *usrData
110  if (i < numParams) {
111    PVD = FD->getParamDecl(i);
112    QT = PVD->getType().getCanonicalType();
113    if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
114      mUsrData = PVD;
115      i++;  // advance parameter pointer
116    }
117  }
118
119  while (i < numParams) {
120    PVD = FD->getParamDecl(i);
121    QT = PVD->getType().getCanonicalType();
122
123    if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
124      Diags->Report(
125          clang::FullSourceLoc(PVD->getLocation(),
126                               Diags->getSourceManager()),
127          Diags->getCustomDiagID(clang::Diagnostic::Error,
128                                 "Unexpected root() parameter '%0' "
129                                 "of type '%1'"))
130          << PVD->getName() << PVD->getType().getAsString();
131      valid = false;
132    } else {
133      llvm::StringRef ParamName = PVD->getName();
134      if (ParamName.equals("x")) {
135        if (mX) {
136          ReportNameError(Diags, PVD);
137          valid = false;
138        } else if (mY) {
139          // Can't go back to X after skipping Y
140          ReportNameError(Diags, PVD);
141          valid = false;
142        } else {
143          mX = PVD;
144        }
145      } else if (ParamName.equals("y")) {
146        if (mY) {
147          ReportNameError(Diags, PVD);
148          valid = false;
149        } else {
150          mY = PVD;
151        }
152      } else {
153        if (!mX && !mY) {
154          mX = PVD;
155        } else if (!mY) {
156          mY = PVD;
157        } else {
158          Diags->Report(
159              clang::FullSourceLoc(PVD->getLocation(),
160                                   Diags->getSourceManager()),
161              Diags->getCustomDiagID(clang::Diagnostic::Error,
162                                     "Unexpected root() parameter '%0' "
163                                     "of type '%1'"))
164              << PVD->getName() << PVD->getType().getAsString();
165          valid = false;
166        }
167      }
168    }
169
170    i++;
171  }
172
173  mMetadataEncoding = 0;
174  if (valid) {
175    // Set up the bitwise metadata encoding for runtime argument passing.
176    mMetadataEncoding |= (mIn ?       0x01 : 0);
177    mMetadataEncoding |= (mOut ?      0x02 : 0);
178    mMetadataEncoding |= (mUsrData ?  0x04 : 0);
179    mMetadataEncoding |= (mX ?        0x08 : 0);
180    mMetadataEncoding |= (mY ?        0x10 : 0);
181  }
182
183  return valid;
184}
185
186RSExportForEach *RSExportForEach::Create(RSContext *Context,
187                                         const clang::FunctionDecl *FD) {
188  slangAssert(Context && FD);
189  llvm::StringRef Name = FD->getName();
190  RSExportForEach *FE;
191
192  slangAssert(!Name.empty() && "Function must have a name");
193
194  FE = new RSExportForEach(Context, Name, FD);
195
196  if (!FE->validateAndConstructParams(Context, FD)) {
197    return NULL;
198  }
199
200  clang::ASTContext &Ctx = Context->getASTContext();
201
202  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
203  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
204
205  // Extract the usrData parameter (if we have one)
206  if (FE->mUsrData) {
207    const clang::ParmVarDecl *PVD = FE->mUsrData;
208    clang::QualType QT = PVD->getType().getCanonicalType();
209    slangAssert(QT->isPointerType() &&
210                QT->getPointeeType().isConstQualified());
211
212    const clang::ASTContext &C = Context->getASTContext();
213    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
214        C.VoidTy) {
215      // In the case of using const void*, we can't reflect an appopriate
216      // Java type, so we fall back to just reflecting the ain/aout parameters
217      FE->mUsrData = NULL;
218    } else {
219      clang::RecordDecl *RD =
220          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
221                                    Ctx.getTranslationUnitDecl(),
222                                    clang::SourceLocation(),
223                                    clang::SourceLocation(),
224                                    &Ctx.Idents.get(Id));
225
226      llvm::StringRef ParamName = PVD->getName();
227      clang::FieldDecl *FD =
228          clang::FieldDecl::Create(Ctx,
229                                   RD,
230                                   clang::SourceLocation(),
231                                   clang::SourceLocation(),
232                                   PVD->getIdentifier(),
233                                   QT->getPointeeType(),
234                                   NULL,
235                                   /* BitWidth = */ NULL,
236                                   /* Mutable = */ false,
237                                   /* HasInit = */ false);
238      RD->addDecl(FD);
239      RD->completeDefinition();
240
241      // Create an export type iff we have a valid usrData type
242      clang::QualType T = Ctx.getTagDeclType(RD);
243      slangAssert(!T.isNull());
244
245      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
246
247      if (ET == NULL) {
248        fprintf(stderr, "Failed to export the function %s. There's at least "
249                        "one parameter whose type is not supported by the "
250                        "reflection\n", FE->getName().c_str());
251        return NULL;
252      }
253
254      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
255                  "Parameter packet must be a record");
256
257      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
258    }
259  }
260
261  if (FE->mIn) {
262    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
263    FE->mInType = RSExportType::Create(Context, T);
264  }
265
266  if (FE->mOut) {
267    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
268    FE->mOutType = RSExportType::Create(Context, T);
269  }
270
271  return FE;
272}
273
274bool RSExportForEach::isRSForEachFunc(const clang::FunctionDecl *FD) {
275  // We currently support only compute root() being exported via forEach
276  if (!isRootRSFunc(FD)) {
277    return false;
278  }
279
280  if (FD->getNumParams() == 0) {
281    // Graphics compute function
282    return false;
283  }
284  return true;
285}
286
287bool RSExportForEach::validateSpecialFuncDecl(clang::Diagnostic *Diags,
288                                              const clang::FunctionDecl *FD) {
289  slangAssert(Diags && FD);
290  bool valid = true;
291  const clang::ASTContext &C = FD->getASTContext();
292
293  if (isRootRSFunc(FD)) {
294    unsigned int numParams = FD->getNumParams();
295    if (numParams == 0) {
296      // Graphics root function, so verify that it returns an int
297      if (FD->getResultType().getCanonicalType() != C.IntTy) {
298        Diags->Report(
299            clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
300            Diags->getCustomDiagID(clang::Diagnostic::Error,
301                                   "root(void) is required to return "
302                                   "an int for graphics usage"));
303        valid = false;
304      }
305    } else {
306      slangAssert(false &&
307          "Should not call validateSpecialFuncDecl() on compute root()");
308    }
309  } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
310    if (FD->getNumParams() != 0) {
311      Diags->Report(
312          clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
313          Diags->getCustomDiagID(clang::Diagnostic::Error,
314                                 "%0(void) is required to have no "
315                                 "parameters")) << FD->getName();
316      valid = false;
317    }
318
319    if (FD->getResultType().getCanonicalType() != C.VoidTy) {
320      Diags->Report(
321          clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
322          Diags->getCustomDiagID(clang::Diagnostic::Error,
323                                 "%0(void) is required to have a void "
324                                 "return type")) << FD->getName();
325      valid = false;
326    }
327  } else {
328    slangAssert(false && "must be called on root, init or .rs.dtor function!");
329  }
330
331  return valid;
332}
333
334}  // namespace slang
335