slang_rs_export_foreach.cpp revision f736d5a12269e7e74740b130cdca98d9839b31e6
1/*
2 * Copyright 2011, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_export_foreach.h"
18
19#include <string>
20
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/Decl.h"
23#include "clang/AST/TypeLoc.h"
24
25#include "llvm/DerivedTypes.h"
26#include "llvm/Target/TargetData.h"
27
28#include "slang_assert.h"
29#include "slang_rs_context.h"
30#include "slang_rs_export_type.h"
31#include "slang_version.h"
32
33namespace slang {
34
35namespace {
36
37static void ReportNameError(clang::Diagnostic *Diags,
38                            const clang::ParmVarDecl *PVD) {
39  slangAssert(Diags && PVD);
40  const clang::SourceManager &SM = Diags->getSourceManager();
41
42  Diags->Report(clang::FullSourceLoc(PVD->getLocation(), SM),
43                Diags->getCustomDiagID(clang::Diagnostic::Error,
44                "Duplicate parameter entry (by position/name): '%0'"))
45       << PVD->getName();
46  return;
47}
48
49}  // namespace
50
51// This function takes care of additional validation and construction of
52// parameters related to forEach_* reflection.
53bool RSExportForEach::validateAndConstructParams(
54    RSContext *Context, const clang::FunctionDecl *FD) {
55  slangAssert(Context && FD);
56  bool valid = true;
57  clang::ASTContext &C = Context->getASTContext();
58  clang::Diagnostic *Diags = Context->getDiagnostics();
59
60  if (!isRootRSFunc(FD)) {
61    slangAssert(false && "must be called on compute root function!");
62  }
63
64  numParams = FD->getNumParams();
65  slangAssert(numParams > 0);
66
67  // Compute root functions are required to return a void type for now
68  if (FD->getResultType().getCanonicalType() != C.VoidTy) {
69    Diags->Report(
70        clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
71        Diags->getCustomDiagID(clang::Diagnostic::Error,
72                               "compute root() is required to return a "
73                               "void type"));
74    valid = false;
75  }
76
77  // Validate remaining parameter types
78  // TODO(all): Add support for LOD/face when we have them
79
80  size_t i = 0;
81  const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
82  clang::QualType QT = PVD->getType().getCanonicalType();
83
84  // Check for const T1 *in
85  if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
86    mIn = PVD;
87    i++;  // advance parameter pointer
88  }
89
90  // Check for T2 *out
91  if (i < numParams) {
92    PVD = FD->getParamDecl(i);
93    QT = PVD->getType().getCanonicalType();
94    if (QT->isPointerType() && !QT->getPointeeType().isConstQualified()) {
95      mOut = PVD;
96      i++;  // advance parameter pointer
97    }
98  }
99
100  if (!mIn && !mOut) {
101    Diags->Report(
102        clang::FullSourceLoc(FD->getLocation(),
103                             Diags->getSourceManager()),
104        Diags->getCustomDiagID(clang::Diagnostic::Error,
105                               "Compute root() must have at least one "
106                               "parameter for in or out"));
107    valid = false;
108  }
109
110  // Check for T3 *usrData
111  if (i < numParams) {
112    PVD = FD->getParamDecl(i);
113    QT = PVD->getType().getCanonicalType();
114    if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
115      mUsrData = PVD;
116      i++;  // advance parameter pointer
117    }
118  }
119
120  while (i < numParams) {
121    PVD = FD->getParamDecl(i);
122    QT = PVD->getType().getCanonicalType();
123
124    if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
125      Diags->Report(
126          clang::FullSourceLoc(PVD->getLocation(),
127                               Diags->getSourceManager()),
128          Diags->getCustomDiagID(clang::Diagnostic::Error,
129                                 "Unexpected root() parameter '%0' "
130                                 "of type '%1'"))
131          << PVD->getName() << PVD->getType().getAsString();
132      valid = false;
133    } else {
134      llvm::StringRef ParamName = PVD->getName();
135      if (ParamName.equals("x")) {
136        if (mX) {
137          ReportNameError(Diags, PVD);
138          valid = false;
139        } else if (mY) {
140          // Can't go back to X after skipping Y
141          ReportNameError(Diags, PVD);
142          valid = false;
143        } else {
144          mX = PVD;
145        }
146      } else if (ParamName.equals("y")) {
147        if (mY) {
148          ReportNameError(Diags, PVD);
149          valid = false;
150        } else {
151          mY = PVD;
152        }
153      } else {
154        if (!mX && !mY) {
155          mX = PVD;
156        } else if (!mY) {
157          mY = PVD;
158        } else {
159          Diags->Report(
160              clang::FullSourceLoc(PVD->getLocation(),
161                                   Diags->getSourceManager()),
162              Diags->getCustomDiagID(clang::Diagnostic::Error,
163                                     "Unexpected root() parameter '%0' "
164                                     "of type '%1'"))
165              << PVD->getName() << PVD->getType().getAsString();
166          valid = false;
167        }
168      }
169    }
170
171    i++;
172  }
173
174  mMetadataEncoding = 0;
175  if (valid) {
176    // Set up the bitwise metadata encoding for runtime argument passing.
177    mMetadataEncoding |= (mIn ?       0x01 : 0);
178    mMetadataEncoding |= (mOut ?      0x02 : 0);
179    mMetadataEncoding |= (mUsrData ?  0x04 : 0);
180    mMetadataEncoding |= (mX ?        0x08 : 0);
181    mMetadataEncoding |= (mY ?        0x10 : 0);
182  }
183
184  if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
185    // APIs before ICS cannot skip between parameters. It is ok, however, for
186    // them to omit further parameters (i.e. skipping X is ok if you skip Y).
187    if (mMetadataEncoding != 0x1f &&  // In, Out, UsrData, X, Y
188        mMetadataEncoding != 0x0f &&  // In, Out, UsrData, X
189        mMetadataEncoding != 0x07 &&  // In, Out, UsrData
190        mMetadataEncoding != 0x03 &&  // In, Out
191        mMetadataEncoding != 0x01) {  // In
192      Diags->Report(
193          clang::FullSourceLoc(FD->getLocation(),
194                               Diags->getSourceManager()),
195          Diags->getCustomDiagID(clang::Diagnostic::Error,
196                                 "Compute root() targeting SDK levels %0-%1 "
197                                 "may not skip parameters"))
198          << SLANG_MINIMUM_TARGET_API << (SLANG_ICS_TARGET_API-1);
199      valid = false;
200    }
201  }
202
203
204  return valid;
205}
206
207RSExportForEach *RSExportForEach::Create(RSContext *Context,
208                                         const clang::FunctionDecl *FD) {
209  slangAssert(Context && FD);
210  llvm::StringRef Name = FD->getName();
211  RSExportForEach *FE;
212
213  slangAssert(!Name.empty() && "Function must have a name");
214
215  FE = new RSExportForEach(Context, Name, FD);
216
217  if (!FE->validateAndConstructParams(Context, FD)) {
218    return NULL;
219  }
220
221  clang::ASTContext &Ctx = Context->getASTContext();
222
223  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
224  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
225
226  // Extract the usrData parameter (if we have one)
227  if (FE->mUsrData) {
228    const clang::ParmVarDecl *PVD = FE->mUsrData;
229    clang::QualType QT = PVD->getType().getCanonicalType();
230    slangAssert(QT->isPointerType() &&
231                QT->getPointeeType().isConstQualified());
232
233    const clang::ASTContext &C = Context->getASTContext();
234    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
235        C.VoidTy) {
236      // In the case of using const void*, we can't reflect an appopriate
237      // Java type, so we fall back to just reflecting the ain/aout parameters
238      FE->mUsrData = NULL;
239    } else {
240      clang::RecordDecl *RD =
241          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
242                                    Ctx.getTranslationUnitDecl(),
243                                    clang::SourceLocation(),
244                                    clang::SourceLocation(),
245                                    &Ctx.Idents.get(Id));
246
247      llvm::StringRef ParamName = PVD->getName();
248      clang::FieldDecl *FD =
249          clang::FieldDecl::Create(Ctx,
250                                   RD,
251                                   clang::SourceLocation(),
252                                   clang::SourceLocation(),
253                                   PVD->getIdentifier(),
254                                   QT->getPointeeType(),
255                                   NULL,
256                                   /* BitWidth = */ NULL,
257                                   /* Mutable = */ false,
258                                   /* HasInit = */ false);
259      RD->addDecl(FD);
260      RD->completeDefinition();
261
262      // Create an export type iff we have a valid usrData type
263      clang::QualType T = Ctx.getTagDeclType(RD);
264      slangAssert(!T.isNull());
265
266      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
267
268      if (ET == NULL) {
269        fprintf(stderr, "Failed to export the function %s. There's at least "
270                        "one parameter whose type is not supported by the "
271                        "reflection\n", FE->getName().c_str());
272        return NULL;
273      }
274
275      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
276                  "Parameter packet must be a record");
277
278      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
279    }
280  }
281
282  if (FE->mIn) {
283    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
284    FE->mInType = RSExportType::Create(Context, T);
285  }
286
287  if (FE->mOut) {
288    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
289    FE->mOutType = RSExportType::Create(Context, T);
290  }
291
292  return FE;
293}
294
295bool RSExportForEach::isRSForEachFunc(int targetAPI,
296    const clang::FunctionDecl *FD) {
297  // We currently support only compute root() being exported via forEach
298  if (!isRootRSFunc(FD)) {
299    return false;
300  }
301
302  if (FD->getNumParams() == 0) {
303    // Graphics compute function
304    return false;
305  }
306
307  // Handle legacy graphics root functions.
308  if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
309    const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
310    clang::QualType QT = PVD->getType().getCanonicalType();
311    const clang::QualType &IntType = FD->getASTContext().IntTy;
312    if ((FD->getResultType().getCanonicalType() == IntType) &&
313        (QT == IntType)) {
314      return false;
315    }
316  }
317
318  return true;
319}
320
321bool RSExportForEach::validateSpecialFuncDecl(int targetAPI,
322                                              clang::Diagnostic *Diags,
323                                              const clang::FunctionDecl *FD) {
324  slangAssert(Diags && FD);
325  bool valid = true;
326  const clang::ASTContext &C = FD->getASTContext();
327
328  if (isRootRSFunc(FD)) {
329    unsigned int numParams = FD->getNumParams();
330    if (numParams == 0) {
331      // Graphics root function, so verify that it returns an int
332      if (FD->getResultType().getCanonicalType() != C.IntTy) {
333        Diags->Report(
334            clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
335            Diags->getCustomDiagID(clang::Diagnostic::Error,
336                                   "root(void) is required to return "
337                                   "an int for graphics usage"));
338        valid = false;
339      }
340    } else if ((targetAPI < SLANG_ICS_TARGET_API) && (numParams == 1)) {
341      // Legacy graphics root function
342      // This has already been validated in isRSForEachFunc().
343    } else {
344      slangAssert(false &&
345          "Should not call validateSpecialFuncDecl() on compute root()");
346    }
347  } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
348    if (FD->getNumParams() != 0) {
349      Diags->Report(
350          clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
351          Diags->getCustomDiagID(clang::Diagnostic::Error,
352                                 "%0(void) is required to have no "
353                                 "parameters")) << FD->getName();
354      valid = false;
355    }
356
357    if (FD->getResultType().getCanonicalType() != C.VoidTy) {
358      Diags->Report(
359          clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
360          Diags->getCustomDiagID(clang::Diagnostic::Error,
361                                 "%0(void) is required to have a void "
362                                 "return type")) << FD->getName();
363      valid = false;
364    }
365  } else {
366    slangAssert(false && "must be called on root, init or .rs.dtor function!");
367  }
368
369  return valid;
370}
371
372}  // namespace slang
373