// slang_rs_export_foreach.cpp revision ee4016d1247d3fbe50822de279d3da273d8aef4c
1/* 2 * Copyright 2011-2012, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "slang_rs_export_foreach.h" 18 19#include <string> 20 21#include "clang/AST/ASTContext.h" 22#include "clang/AST/Attr.h" 23#include "clang/AST/Decl.h" 24#include "clang/AST/TypeLoc.h" 25 26#include "llvm/IR/DerivedTypes.h" 27 28#include "slang_assert.h" 29#include "slang_rs_context.h" 30#include "slang_rs_export_type.h" 31#include "slang_version.h" 32 33namespace slang { 34 35// This function takes care of additional validation and construction of 36// parameters related to forEach_* reflection. 37bool RSExportForEach::validateAndConstructParams( 38 RSContext *Context, const clang::FunctionDecl *FD) { 39 slangAssert(Context && FD); 40 bool valid = true; 41 42 numParams = FD->getNumParams(); 43 44 if (Context->getTargetAPI() < SLANG_JB_TARGET_API) { 45 // Before JellyBean, we allowed only one kernel per file. It must be called "root". 46 if (!isRootRSFunc(FD)) { 47 Context->ReportError(FD->getLocation(), 48 "Non-root compute kernel %0() is " 49 "not supported in SDK levels %1-%2") 50 << FD->getName() << SLANG_MINIMUM_TARGET_API 51 << (SLANG_JB_TARGET_API - 1); 52 return false; 53 } 54 } 55 56 mResultType = FD->getReturnType().getCanonicalType(); 57 // Compute kernel functions are defined differently when the 58 // "__attribute__((kernel))" is set. 
59 if (FD->hasAttr<clang::KernelAttr>()) { 60 valid |= validateAndConstructKernelParams(Context, FD); 61 } else { 62 valid |= validateAndConstructOldStyleParams(Context, FD); 63 } 64 valid |= setSignatureMetadata(Context, FD); 65 return valid; 66} 67 68bool RSExportForEach::validateAndConstructOldStyleParams( 69 RSContext *Context, const clang::FunctionDecl *FD) { 70 slangAssert(Context && FD); 71 // If numParams is 0, we already marked this as a graphics root(). 72 slangAssert(numParams > 0); 73 74 bool valid = true; 75 76 // Compute kernel functions of this style are required to return a void type. 77 clang::ASTContext &C = Context->getASTContext(); 78 if (mResultType != C.VoidTy) { 79 Context->ReportError(FD->getLocation(), 80 "Compute kernel %0() is required to return a " 81 "void type") 82 << FD->getName(); 83 valid = false; 84 } 85 86 // Validate remaining parameter types 87 // TODO(all): Add support for LOD/face when we have them 88 89 size_t IndexOfFirstIterator = numParams; 90 valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator); 91 92 // Validate the non-iterator parameters, which should all be found before the 93 // first iterator. 94 for (size_t i = 0; i < IndexOfFirstIterator; i++) { 95 const clang::ParmVarDecl *PVD = FD->getParamDecl(i); 96 clang::QualType QT = PVD->getType().getCanonicalType(); 97 98 if (!QT->isPointerType()) { 99 Context->ReportError(PVD->getLocation(), 100 "Compute kernel %0() cannot have non-pointer " 101 "parameters besides 'x' and 'y'. Parameter '%1' is " 102 "of type: '%2'") 103 << FD->getName() << PVD->getName() << PVD->getType().getAsString(); 104 valid = false; 105 continue; 106 } 107 108 // The only non-const pointer should be out. 109 if (!QT->getPointeeType().isConstQualified()) { 110 if (mOut == NULL) { 111 mOut = PVD; 112 } else { 113 Context->ReportError(PVD->getLocation(), 114 "Compute kernel %0() can only have one non-const " 115 "pointer parameter. 
Parameters '%1' and '%2' are " 116 "both non-const.") 117 << FD->getName() << mOut->getName() << PVD->getName(); 118 valid = false; 119 } 120 } else { 121 if (mIn == NULL && mOut == NULL) { 122 mIn = PVD; 123 } else if (mUsrData == NULL) { 124 mUsrData = PVD; 125 } else { 126 Context->ReportError( 127 PVD->getLocation(), 128 "Unexpected parameter '%0' for compute kernel %1()") 129 << PVD->getName() << FD->getName(); 130 valid = false; 131 } 132 } 133 } 134 135 if (!mIn && !mOut) { 136 Context->ReportError(FD->getLocation(), 137 "Compute kernel %0() must have at least one " 138 "parameter for in or out") 139 << FD->getName(); 140 valid = false; 141 } 142 143 return valid; 144} 145 146bool RSExportForEach::validateAndConstructKernelParams( 147 RSContext *Context, const clang::FunctionDecl *FD) { 148 slangAssert(Context && FD); 149 bool valid = true; 150 clang::ASTContext &C = Context->getASTContext(); 151 152 if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) { 153 Context->ReportError(FD->getLocation(), 154 "Compute kernel %0() targeting SDK levels " 155 "%1-%2 may not use pass-by-value with " 156 "__attribute__((kernel))") 157 << FD->getName() << SLANG_MINIMUM_TARGET_API 158 << (SLANG_JB_MR1_TARGET_API - 1); 159 return false; 160 } 161 162 // Denote that we are indeed a pass-by-value kernel. 163 mIsKernelStyle = true; 164 mHasReturnType = (mResultType != C.VoidTy); 165 166 if (mResultType->isPointerType()) { 167 Context->ReportError( 168 FD->getTypeSpecStartLoc(), 169 "Compute kernel %0() cannot return a pointer type: '%1'") 170 << FD->getName() << mResultType.getAsString(); 171 valid = false; 172 } 173 174 // Validate remaining parameter types 175 // TODO(all): Add support for LOD/face when we have them 176 177 size_t IndexOfFirstIterator = numParams; 178 valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator); 179 180 // Validate the non-iterator parameters, which should all be found before the 181 // first iterator. 
182 for (size_t i = 0; i < IndexOfFirstIterator; i++) { 183 const clang::ParmVarDecl *PVD = FD->getParamDecl(i); 184 if (i == 0) { 185 mIn = PVD; 186 } else { 187 Context->ReportError(PVD->getLocation(), 188 "Unrecognized parameter '%0'. Compute kernel %1() " 189 "can only have one input parameter, 'x', and 'y'") 190 << PVD->getName() << FD->getName(); 191 valid = false; 192 } 193 clang::QualType QT = PVD->getType().getCanonicalType(); 194 if (QT->isPointerType()) { 195 Context->ReportError(PVD->getLocation(), 196 "Compute kernel %0() cannot have " 197 "parameter '%1' of pointer type: '%2'") 198 << FD->getName() << PVD->getName() << PVD->getType().getAsString(); 199 valid = false; 200 } 201 } 202 203 // Check that we have at least one allocation to use for dimensions. 204 if (valid && !mIn && !mHasReturnType) { 205 Context->ReportError(FD->getLocation(), 206 "Compute kernel %0() must have at least one " 207 "input parameter or a non-void return " 208 "type") 209 << FD->getName(); 210 valid = false; 211 } 212 213 return valid; 214} 215 216// Search for the optional x and y parameters. Returns true if valid. Also 217// sets *IndexOfFirstIterator to the index of the first iterator parameter, or 218// FD->getNumParams() if none are found. 219bool RSExportForEach::validateIterationParameters( 220 RSContext *Context, const clang::FunctionDecl *FD, 221 size_t *IndexOfFirstIterator) { 222 slangAssert(IndexOfFirstIterator != NULL); 223 slangAssert(mX == NULL && mY == NULL); 224 clang::ASTContext &C = Context->getASTContext(); 225 226 // Find the x and y parameters if present. 227 size_t NumParams = FD->getNumParams(); 228 *IndexOfFirstIterator = NumParams; 229 bool valid = true; 230 for (size_t i = 0; i < NumParams; i++) { 231 const clang::ParmVarDecl *PVD = FD->getParamDecl(i); 232 llvm::StringRef ParamName = PVD->getName(); 233 if (ParamName.equals("x")) { 234 slangAssert(mX == NULL); // We won't be invoked if two 'x' are present. 
235 mX = PVD; 236 if (mY != NULL) { 237 Context->ReportError(PVD->getLocation(), 238 "In compute kernel %0(), parameter 'x' should " 239 "be defined before parameter 'y'") 240 << FD->getName(); 241 valid = false; 242 } 243 } else if (ParamName.equals("y")) { 244 slangAssert(mY == NULL); // We won't be invoked if two 'y' are present. 245 mY = PVD; 246 } else { 247 // It's neither x nor y. 248 if (*IndexOfFirstIterator < NumParams) { 249 Context->ReportError(PVD->getLocation(), 250 "In compute kernel %0(), parameter '%1' cannot " 251 "appear after the 'x' and 'y' parameters") 252 << FD->getName() << ParamName; 253 valid = false; 254 } 255 continue; 256 } 257 // Validate the data type of x and y. 258 clang::QualType QT = PVD->getType().getCanonicalType(); 259 clang::QualType UT = QT.getUnqualifiedType(); 260 if (UT != C.UnsignedIntTy && UT != C.IntTy) { 261 Context->ReportError(PVD->getLocation(), 262 "Parameter '%0' must be of type 'int' or " 263 "'unsigned int'. It is of type '%1'") 264 << ParamName << PVD->getType().getAsString(); 265 valid = false; 266 } 267 // If this is the first time we find an iterator, save it. 268 if (*IndexOfFirstIterator >= NumParams) { 269 *IndexOfFirstIterator = i; 270 } 271 } 272 // Check that x and y have the same type. 273 if (mX != NULL and mY != NULL) { 274 clang::QualType XType = mX->getType(); 275 clang::QualType YType = mY->getType(); 276 277 if (XType != YType) { 278 Context->ReportError(mY->getLocation(), 279 "Parameter 'x' and 'y' must be of the same type. 
" 280 "'x' is of type '%0' while 'y' is of type '%1'") 281 << XType.getAsString() << YType.getAsString(); 282 valid = false; 283 } 284 } 285 return valid; 286} 287 288bool RSExportForEach::setSignatureMetadata(RSContext *Context, 289 const clang::FunctionDecl *FD) { 290 mSignatureMetadata = 0; 291 bool valid = true; 292 293 if (mIsKernelStyle) { 294 slangAssert(mOut == NULL); 295 slangAssert(mUsrData == NULL); 296 } else { 297 slangAssert(!mHasReturnType); 298 } 299 300 // Set up the bitwise metadata encoding for runtime argument passing. 301 // TODO: If this bit field is re-used from C++ code, define the values in a header. 302 const bool HasOut = mOut || mHasReturnType; 303 mSignatureMetadata |= (mIn ? 0x01 : 0); 304 mSignatureMetadata |= (HasOut ? 0x02 : 0); 305 mSignatureMetadata |= (mUsrData ? 0x04 : 0); 306 mSignatureMetadata |= (mX ? 0x08 : 0); 307 mSignatureMetadata |= (mY ? 0x10 : 0); 308 mSignatureMetadata |= (mIsKernelStyle ? 0x20 : 0); // pass-by-value 309 310 if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) { 311 // APIs before ICS cannot skip between parameters. It is ok, however, for 312 // them to omit further parameters (i.e. skipping X is ok if you skip Y). 
313 if (mSignatureMetadata != 0x1f && // In, Out, UsrData, X, Y 314 mSignatureMetadata != 0x0f && // In, Out, UsrData, X 315 mSignatureMetadata != 0x07 && // In, Out, UsrData 316 mSignatureMetadata != 0x03 && // In, Out 317 mSignatureMetadata != 0x01) { // In 318 Context->ReportError(FD->getLocation(), 319 "Compute kernel %0() targeting SDK levels " 320 "%1-%2 may not skip parameters") 321 << FD->getName() << SLANG_MINIMUM_TARGET_API 322 << (SLANG_ICS_TARGET_API - 1); 323 valid = false; 324 } 325 } 326 return valid; 327} 328 329RSExportForEach *RSExportForEach::Create(RSContext *Context, 330 const clang::FunctionDecl *FD) { 331 slangAssert(Context && FD); 332 llvm::StringRef Name = FD->getName(); 333 RSExportForEach *FE; 334 335 slangAssert(!Name.empty() && "Function must have a name"); 336 337 FE = new RSExportForEach(Context, Name); 338 339 if (!FE->validateAndConstructParams(Context, FD)) { 340 return NULL; 341 } 342 343 clang::ASTContext &Ctx = Context->getASTContext(); 344 345 std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:"); 346 Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX); 347 348 // Extract the usrData parameter (if we have one) 349 if (FE->mUsrData) { 350 const clang::ParmVarDecl *PVD = FE->mUsrData; 351 clang::QualType QT = PVD->getType().getCanonicalType(); 352 slangAssert(QT->isPointerType() && 353 QT->getPointeeType().isConstQualified()); 354 355 const clang::ASTContext &C = Context->getASTContext(); 356 if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() == 357 C.VoidTy) { 358 // In the case of using const void*, we can't reflect an appopriate 359 // Java type, so we fall back to just reflecting the ain/aout parameters 360 FE->mUsrData = NULL; 361 } else { 362 clang::RecordDecl *RD = 363 clang::RecordDecl::Create(Ctx, clang::TTK_Struct, 364 Ctx.getTranslationUnitDecl(), 365 clang::SourceLocation(), 366 clang::SourceLocation(), 367 &Ctx.Idents.get(Id)); 368 369 clang::FieldDecl *FD = 370 
clang::FieldDecl::Create(Ctx, 371 RD, 372 clang::SourceLocation(), 373 clang::SourceLocation(), 374 PVD->getIdentifier(), 375 QT->getPointeeType(), 376 NULL, 377 /* BitWidth = */ NULL, 378 /* Mutable = */ false, 379 /* HasInit = */ clang::ICIS_NoInit); 380 RD->addDecl(FD); 381 RD->completeDefinition(); 382 383 // Create an export type iff we have a valid usrData type 384 clang::QualType T = Ctx.getTagDeclType(RD); 385 slangAssert(!T.isNull()); 386 387 RSExportType *ET = RSExportType::Create(Context, T.getTypePtr()); 388 389 if (ET == NULL) { 390 fprintf(stderr, "Failed to export the function %s. There's at least " 391 "one parameter whose type is not supported by the " 392 "reflection\n", FE->getName().c_str()); 393 return NULL; 394 } 395 396 slangAssert((ET->getClass() == RSExportType::ExportClassRecord) && 397 "Parameter packet must be a record"); 398 399 FE->mParamPacketType = static_cast<RSExportRecordType *>(ET); 400 } 401 } 402 403 if (FE->mIn) { 404 const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr(); 405 FE->mInType = RSExportType::Create(Context, T); 406 if (FE->mIsKernelStyle) { 407 slangAssert(FE->mInType); 408 } 409 } 410 411 if (FE->mIsKernelStyle && FE->mHasReturnType) { 412 const clang::Type *T = FE->mResultType.getTypePtr(); 413 FE->mOutType = RSExportType::Create(Context, T); 414 slangAssert(FE->mOutType); 415 } else if (FE->mOut) { 416 const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr(); 417 FE->mOutType = RSExportType::Create(Context, T); 418 } 419 420 return FE; 421} 422 423RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) { 424 slangAssert(Context); 425 llvm::StringRef Name = "root"; 426 RSExportForEach *FE = new RSExportForEach(Context, Name); 427 FE->mDummyRoot = true; 428 return FE; 429} 430 431bool RSExportForEach::isGraphicsRootRSFunc(int targetAPI, 432 const clang::FunctionDecl *FD) { 433 if (FD->hasAttr<clang::KernelAttr>()) { 434 return false; 435 } 436 437 if 
(!isRootRSFunc(FD)) { 438 return false; 439 } 440 441 if (FD->getNumParams() == 0) { 442 // Graphics root function 443 return true; 444 } 445 446 // Check for legacy graphics root function (with single parameter). 447 if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) { 448 const clang::QualType &IntType = FD->getASTContext().IntTy; 449 if (FD->getReturnType().getCanonicalType() == IntType) { 450 return true; 451 } 452 } 453 454 return false; 455} 456 457bool RSExportForEach::isRSForEachFunc(int targetAPI, 458 slang::RSContext* Context, 459 const clang::FunctionDecl *FD) { 460 slangAssert(Context && FD); 461 bool hasKernelAttr = FD->hasAttr<clang::KernelAttr>(); 462 463 if (FD->getStorageClass() == clang::SC_Static) { 464 if (hasKernelAttr) { 465 Context->ReportError(FD->getLocation(), 466 "Invalid use of attribute kernel with " 467 "static function declaration: %0") 468 << FD->getName(); 469 } 470 return false; 471 } 472 473 // Anything tagged as a kernel is definitely used with ForEach. 474 if (hasKernelAttr) { 475 return true; 476 } 477 478 if (isGraphicsRootRSFunc(targetAPI, FD)) { 479 return false; 480 } 481 482 // Check if first parameter is a pointer (which is required for ForEach). 483 unsigned int numParams = FD->getNumParams(); 484 485 if (numParams > 0) { 486 const clang::ParmVarDecl *PVD = FD->getParamDecl(0); 487 clang::QualType QT = PVD->getType().getCanonicalType(); 488 489 if (QT->isPointerType()) { 490 return true; 491 } 492 493 // Any non-graphics root() is automatically a ForEach candidate. 494 // At this point, however, we know that it is not going to be a valid 495 // compute root() function (due to not having a pointer parameter). We 496 // still want to return true here, so that we can issue appropriate 497 // diagnostics. 
498 if (isRootRSFunc(FD)) { 499 return true; 500 } 501 } 502 503 return false; 504} 505 506bool 507RSExportForEach::validateSpecialFuncDecl(int targetAPI, 508 slang::RSContext *Context, 509 clang::FunctionDecl const *FD) { 510 slangAssert(Context && FD); 511 bool valid = true; 512 const clang::ASTContext &C = FD->getASTContext(); 513 const clang::QualType &IntType = FD->getASTContext().IntTy; 514 515 if (isGraphicsRootRSFunc(targetAPI, FD)) { 516 if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) { 517 // Legacy graphics root function 518 const clang::ParmVarDecl *PVD = FD->getParamDecl(0); 519 clang::QualType QT = PVD->getType().getCanonicalType(); 520 if (QT != IntType) { 521 Context->ReportError(PVD->getLocation(), 522 "invalid parameter type for legacy " 523 "graphics root() function: %0") 524 << PVD->getType(); 525 valid = false; 526 } 527 } 528 529 // Graphics root function, so verify that it returns an int 530 if (FD->getReturnType().getCanonicalType() != IntType) { 531 Context->ReportError(FD->getLocation(), 532 "root() is required to return " 533 "an int for graphics usage"); 534 valid = false; 535 } 536 } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) { 537 if (FD->getNumParams() != 0) { 538 Context->ReportError(FD->getLocation(), 539 "%0(void) is required to have no " 540 "parameters") 541 << FD->getName(); 542 valid = false; 543 } 544 545 if (FD->getReturnType().getCanonicalType() != C.VoidTy) { 546 Context->ReportError(FD->getLocation(), 547 "%0(void) is required to have a void " 548 "return type") 549 << FD->getName(); 550 valid = false; 551 } 552 } else { 553 slangAssert(false && "must be called on root, init or .rs.dtor function!"); 554 } 555 556 return valid; 557} 558 559} // namespace slang 560