1/* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include <iomanip> 18#include <iostream> 19#include <cmath> 20#include <sstream> 21 22#include "Generator.h" 23#include "Specification.h" 24#include "Utilities.h" 25 26using namespace std; 27 28// Converts float2 to FLOAT_32 and 2, etc. 29static void convertToRsType(const string& name, string* dataType, char* vectorSize) { 30 string s = name; 31 int last = s.size() - 1; 32 char lastChar = s[last]; 33 if (lastChar >= '1' && lastChar <= '4') { 34 s.erase(last); 35 *vectorSize = lastChar; 36 } else { 37 *vectorSize = '1'; 38 } 39 dataType->clear(); 40 for (int i = 0; i < NUM_TYPES; i++) { 41 if (s == TYPES[i].cType) { 42 *dataType = TYPES[i].rsDataType; 43 break; 44 } 45 } 46} 47 48// Returns true if any permutation of the function have tests to b 49static bool needTestFiles(const Function& function, unsigned int versionOfTestFiles) { 50 for (auto spec : function.getSpecifications()) { 51 if (spec->hasTests(versionOfTestFiles)) { 52 return true; 53 } 54 } 55 return false; 56} 57 58/* One instance of this class is generated for each permutation of a function for which 59 * we are generating test code. This instance will generate both the script and the Java 60 * section of the test files for this permutation. The class is mostly used to keep track 61 * of the various names shared between script and Java files. 62 * WARNING: Because the constructor keeps a reference to the FunctionPermutation, PermutationWriter 63 * should not exceed the lifetime of FunctionPermutation. 64 */ 65class PermutationWriter { 66private: 67 FunctionPermutation& mPermutation; 68 69 string mRsKernelName; 70 string mJavaArgumentsClassName; 71 string mJavaArgumentsNClassName; 72 string mJavaVerifierComputeMethodName; 73 string mJavaVerifierVerifyMethodName; 74 string mJavaCheckMethodName; 75 string mJavaVerifyMethodName; 76 77 // Pointer to the files we are generating. Handy to avoid always passing them in the calls. 78 GeneratedFile* mRs; 79 GeneratedFile* mJava; 80 81 /* Shortcuts to the return parameter and the first input parameter of the function 82 * specification. 83 */ 84 const ParameterDefinition* mReturnParam; // Can be nullptr. NOT OWNED. 85 const ParameterDefinition* mFirstInputParam; // Can be nullptr. NOT OWNED. 86 87 /* All the parameters plus the return param, if present. Collecting them together 88 * simplifies code generation. NOT OWNED. 89 */ 90 vector<const ParameterDefinition*> mAllInputsAndOutputs; 91 92 /* We use a class to pass the arguments between the generated code and the CoreVerifier. This 93 * method generates this class. The set keeps track if we've generated this class already 94 * for this test file, as more than one permutation may use the same argument class. 95 */ 96 void writeJavaArgumentClass(bool scalar, set<string>* javaGeneratedArgumentClasses) const; 97 98 // Generate the Check* method that invokes the script and calls the verifier. 99 void writeJavaCheckMethod(bool generateCallToVerifier) const; 100 101 // Generate code to define and randomly initialize the input allocation. 102 void writeJavaInputAllocationDefinition(const ParameterDefinition& param) const; 103 104 /* Generate code that instantiate an allocation of floats or integers and fills it with 105 * random data. This random data must be compatible with the specified type. This is 106 * used for the convert_* tests, as converting values that don't fit yield undefined results. 107 */ 108 void writeJavaRandomCompatibleFloatAllocation(const string& dataType, const string& seed, 109 char vectorSize, 110 const NumericalType& compatibleType, 111 const NumericalType& generatedType) const; 112 void writeJavaRandomCompatibleIntegerAllocation(const string& dataType, const string& seed, 113 char vectorSize, 114 const NumericalType& compatibleType, 115 const NumericalType& generatedType) const; 116 117 // Generate code that defines an output allocation. 118 void writeJavaOutputAllocationDefinition(const ParameterDefinition& param) const; 119 120 /* Generate the code that verifies the results for RenderScript functions where each entry 121 * of a vector is evaluated independently. If verifierValidates is true, CoreMathVerifier 122 * does the actual validation instead of more commonly returning the range of acceptable values. 123 */ 124 void writeJavaVerifyScalarMethod(bool verifierValidates) const; 125 126 /* Generate the code that verify the results for a RenderScript function where a vector 127 * is a point in n-dimensional space. 128 */ 129 void writeJavaVerifyVectorMethod() const; 130 131 // Generate the line that creates the Target. 132 void writeJavaCreateTarget() const; 133 134 // Generate the method header of the verify function. 135 void writeJavaVerifyMethodHeader() const; 136 137 // Generate codes that copies the content of an allocation to an array. 138 void writeJavaArrayInitialization(const ParameterDefinition& p) const; 139 140 // Generate code that tests one value returned from the script. 141 void writeJavaTestAndSetValid(const ParameterDefinition& p, const string& argsIndex, 142 const string& actualIndex) const; 143 void writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex, 144 const string& actualIndex) const; 145 // For test:vector cases, generate code that compares returned vector vs. expected value. 146 void writeJavaVectorComparison(const ParameterDefinition& p) const; 147 148 // Muliple functions that generates code to build the error message if an error is found. 149 void writeJavaAppendOutputToMessage(const ParameterDefinition& p, const string& argsIndex, 150 const string& actualIndex, bool verifierValidates) const; 151 void writeJavaAppendInputToMessage(const ParameterDefinition& p, const string& actual) const; 152 void writeJavaAppendNewLineToMessage() const; 153 void writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const; 154 void writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const; 155 156 // Generate the set of instructions to call the script. 157 void writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const; 158 159 // Write an allocation definition if not already emitted in the .rs file. 160 void writeRsAllocationDefinition(const ParameterDefinition& param, 161 set<string>* rsAllocationsGenerated) const; 162 163public: 164 /* NOTE: We keep pointers to the permutation and the files. This object should not 165 * outlive the arguments. 166 */ 167 PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile, 168 GeneratedFile* javaFile); 169 string getJavaCheckMethodName() const { return mJavaCheckMethodName; } 170 171 // Write the script test function for this permutation. 172 void writeRsSection(set<string>* rsAllocationsGenerated) const; 173 // Write the section of the Java code that calls the script and validates the results 174 void writeJavaSection(set<string>* javaGeneratedArgumentClasses) const; 175}; 176 177PermutationWriter::PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile, 178 GeneratedFile* javaFile) 179 : mPermutation(permutation), 180 mRs(rsFile), 181 mJava(javaFile), 182 mReturnParam(nullptr), 183 mFirstInputParam(nullptr) { 184 mRsKernelName = "test" + capitalize(permutation.getName()); 185 186 mJavaArgumentsClassName = "Arguments"; 187 mJavaArgumentsNClassName = "Arguments"; 188 const string trunk = capitalize(permutation.getNameTrunk()); 189 mJavaCheckMethodName = "check" + trunk; 190 mJavaVerifyMethodName = "verifyResults" + trunk; 191 192 for (auto p : permutation.getParams()) { 193 mAllInputsAndOutputs.push_back(p); 194 if (mFirstInputParam == nullptr && !p->isOutParameter) { 195 mFirstInputParam = p; 196 } 197 } 198 mReturnParam = permutation.getReturn(); 199 if (mReturnParam) { 200 mAllInputsAndOutputs.push_back(mReturnParam); 201 } 202 203 for (auto p : mAllInputsAndOutputs) { 204 const string capitalizedRsType = capitalize(p->rsType); 205 const string capitalizedBaseType = capitalize(p->rsBaseType); 206 mRsKernelName += capitalizedRsType; 207 mJavaArgumentsClassName += capitalizedBaseType; 208 mJavaArgumentsNClassName += capitalizedBaseType; 209 if (p->mVectorSize != "1") { 210 mJavaArgumentsNClassName += "N"; 211 } 212 mJavaCheckMethodName += capitalizedRsType; 213 mJavaVerifyMethodName += capitalizedRsType; 214 } 215 mJavaVerifierComputeMethodName = "compute" + trunk; 216 mJavaVerifierVerifyMethodName = "verify" + trunk; 217} 218 219void PermutationWriter::writeJavaSection(set<string>* javaGeneratedArgumentClasses) const { 220 // By default, we test the results using item by item comparison. 221 const string test = mPermutation.getTest(); 222 if (test == "scalar" || test == "limited") { 223 writeJavaArgumentClass(true, javaGeneratedArgumentClasses); 224 writeJavaCheckMethod(true); 225 writeJavaVerifyScalarMethod(false); 226 } else if (test == "custom") { 227 writeJavaArgumentClass(true, javaGeneratedArgumentClasses); 228 writeJavaCheckMethod(true); 229 writeJavaVerifyScalarMethod(true); 230 } else if (test == "vector") { 231 writeJavaArgumentClass(false, javaGeneratedArgumentClasses); 232 writeJavaCheckMethod(true); 233 writeJavaVerifyVectorMethod(); 234 } else if (test == "noverify") { 235 writeJavaCheckMethod(false); 236 } 237} 238 239void PermutationWriter::writeJavaArgumentClass(bool scalar, 240 set<string>* javaGeneratedArgumentClasses) const { 241 string name; 242 if (scalar) { 243 name = mJavaArgumentsClassName; 244 } else { 245 name = mJavaArgumentsNClassName; 246 } 247 248 // Make sure we have not generated the argument class already. 249 if (!testAndSet(name, javaGeneratedArgumentClasses)) { 250 mJava->indent() << "public class " << name; 251 mJava->startBlock(); 252 253 for (auto p : mAllInputsAndOutputs) { 254 bool isFieldArray = !scalar && p->mVectorSize != "1"; 255 bool isFloatyField = p->isOutParameter && p->isFloatType && mPermutation.getTest() != "custom"; 256 257 mJava->indent() << "public "; 258 if (isFloatyField) { 259 *mJava << "Target.Floaty"; 260 } else { 261 *mJava << p->javaBaseType; 262 } 263 if (isFieldArray) { 264 *mJava << "[]"; 265 } 266 *mJava << " " << p->variableName << ";\n"; 267 268 // For Float16 parameters, add an extra 'double' field in the class 269 // to hold the Double value converted from the input. 270 if (p->isFloat16Parameter() && !isFloatyField) { 271 mJava->indent() << "public double"; 272 if (isFieldArray) { 273 *mJava << "[]"; 274 } 275 *mJava << " " + p->variableName << "Double;\n"; 276 } 277 } 278 mJava->endBlock(); 279 *mJava << "\n"; 280 } 281} 282 283void PermutationWriter::writeJavaCheckMethod(bool generateCallToVerifier) const { 284 mJava->indent() << "private void " << mJavaCheckMethodName << "()"; 285 mJava->startBlock(); 286 287 // Generate the input allocations and initialization. 288 for (auto p : mAllInputsAndOutputs) { 289 if (!p->isOutParameter) { 290 writeJavaInputAllocationDefinition(*p); 291 } 292 } 293 // Generate code to enforce ordering between two allocations if needed. 294 for (auto p : mAllInputsAndOutputs) { 295 if (!p->isOutParameter && !p->smallerParameter.empty()) { 296 string smallerAlloc = "in" + capitalize(p->smallerParameter); 297 mJava->indent() << "enforceOrdering(" << smallerAlloc << ", " << p->javaAllocName 298 << ");\n"; 299 } 300 } 301 302 // Generate code to check the full and relaxed scripts. 303 writeJavaCallToRs(false, generateCallToVerifier); 304 writeJavaCallToRs(true, generateCallToVerifier); 305 306 mJava->endBlock(); 307 *mJava << "\n"; 308} 309 310void PermutationWriter::writeJavaInputAllocationDefinition(const ParameterDefinition& param) const { 311 string dataType; 312 char vectorSize; 313 convertToRsType(param.rsType, &dataType, &vectorSize); 314 315 const string seed = hashString(mJavaCheckMethodName + param.javaAllocName); 316 mJava->indent() << "Allocation " << param.javaAllocName << " = "; 317 if (param.compatibleTypeIndex >= 0) { 318 if (TYPES[param.typeIndex].kind == FLOATING_POINT) { 319 writeJavaRandomCompatibleFloatAllocation(dataType, seed, vectorSize, 320 TYPES[param.compatibleTypeIndex], 321 TYPES[param.typeIndex]); 322 } else { 323 writeJavaRandomCompatibleIntegerAllocation(dataType, seed, vectorSize, 324 TYPES[param.compatibleTypeIndex], 325 TYPES[param.typeIndex]); 326 } 327 } else if (!param.minValue.empty()) { 328 *mJava << "createRandomFloatAllocation(mRS, Element.DataType." << dataType << ", " 329 << vectorSize << ", " << seed << ", " << param.minValue << ", " << param.maxValue 330 << ")"; 331 } else { 332 /* TODO Instead of passing always false, check whether we are doing a limited test. 333 * Use instead: (mPermutation.getTest() == "limited" ? "false" : "true") 334 */ 335 *mJava << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize 336 << ", " << seed << ", false)"; 337 } 338 *mJava << ";\n"; 339} 340 341void PermutationWriter::writeJavaRandomCompatibleFloatAllocation( 342 const string& dataType, const string& seed, char vectorSize, 343 const NumericalType& compatibleType, const NumericalType& generatedType) const { 344 *mJava << "createRandomFloatAllocation" 345 << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", "; 346 double minValue = 0.0; 347 double maxValue = 0.0; 348 switch (compatibleType.kind) { 349 case FLOATING_POINT: { 350 // We're generating floating point values. We just worry about the exponent. 351 // Subtract 1 for the exponent sign. 352 int bits = min(compatibleType.exponentBits, generatedType.exponentBits) - 1; 353 maxValue = ldexp(0.95, (1 << bits) - 1); 354 minValue = -maxValue; 355 break; 356 } 357 case UNSIGNED_INTEGER: 358 maxValue = maxDoubleForInteger(compatibleType.significantBits, 359 generatedType.significantBits); 360 minValue = 0.0; 361 break; 362 case SIGNED_INTEGER: 363 maxValue = maxDoubleForInteger(compatibleType.significantBits, 364 generatedType.significantBits); 365 minValue = -maxValue - 1.0; 366 break; 367 } 368 *mJava << scientific << std::setprecision(19); 369 *mJava << minValue << ", " << maxValue << ")"; 370 mJava->unsetf(ios_base::floatfield); 371} 372 373void PermutationWriter::writeJavaRandomCompatibleIntegerAllocation( 374 const string& dataType, const string& seed, char vectorSize, 375 const NumericalType& compatibleType, const NumericalType& generatedType) const { 376 *mJava << "createRandomIntegerAllocation" 377 << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", "; 378 379 if (compatibleType.kind == FLOATING_POINT) { 380 // Currently, all floating points can take any number we generate. 381 bool isSigned = generatedType.kind == SIGNED_INTEGER; 382 *mJava << (isSigned ? "true" : "false") << ", " << generatedType.significantBits; 383 } else { 384 bool isSigned = 385 compatibleType.kind == SIGNED_INTEGER && generatedType.kind == SIGNED_INTEGER; 386 *mJava << (isSigned ? "true" : "false") << ", " 387 << min(compatibleType.significantBits, generatedType.significantBits); 388 } 389 *mJava << ")"; 390} 391 392void PermutationWriter::writeJavaOutputAllocationDefinition( 393 const ParameterDefinition& param) const { 394 string dataType; 395 char vectorSize; 396 convertToRsType(param.rsType, &dataType, &vectorSize); 397 mJava->indent() << "Allocation " << param.javaAllocName << " = Allocation.createSized(mRS, " 398 << "getElement(mRS, Element.DataType." << dataType << ", " << vectorSize 399 << "), INPUTSIZE);\n"; 400} 401 402void PermutationWriter::writeJavaVerifyScalarMethod(bool verifierValidates) const { 403 writeJavaVerifyMethodHeader(); 404 mJava->startBlock(); 405 406 string vectorSize = "1"; 407 for (auto p : mAllInputsAndOutputs) { 408 writeJavaArrayInitialization(*p); 409 if (p->mVectorSize != "1" && p->mVectorSize != vectorSize) { 410 if (vectorSize == "1") { 411 vectorSize = p->mVectorSize; 412 } else { 413 cerr << "Error. Had vector " << vectorSize << " and " << p->mVectorSize << "\n"; 414 } 415 } 416 } 417 418 mJava->indent() << "StringBuilder message = new StringBuilder();\n"; 419 mJava->indent() << "boolean errorFound = false;\n"; 420 mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)"; 421 mJava->startBlock(); 422 423 mJava->indent() << "for (int j = 0; j < " << vectorSize << " ; j++)"; 424 mJava->startBlock(); 425 426 mJava->indent() << "// Extract the inputs.\n"; 427 mJava->indent() << mJavaArgumentsClassName << " args = new " << mJavaArgumentsClassName 428 << "();\n"; 429 for (auto p : mAllInputsAndOutputs) { 430 if (!p->isOutParameter) { 431 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i"; 432 if (p->vectorWidth != "1") { 433 *mJava << " * " << p->vectorWidth << " + j"; 434 } 435 *mJava << "];\n"; 436 437 // Convert the Float16 parameter to double and store it in the appropriate field in the 438 // Arguments class. 439 if (p->isFloat16Parameter()) { 440 mJava->indent() << "args." << p->doubleVariableName 441 << " = Float16Utils.convertFloat16ToDouble(args." 442 << p->variableName << ");\n"; 443 } 444 } 445 } 446 const bool hasFloat = mPermutation.hasFloatAnswers(); 447 if (verifierValidates) { 448 mJava->indent() << "// Extract the outputs.\n"; 449 for (auto p : mAllInputsAndOutputs) { 450 if (p->isOutParameter) { 451 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName 452 << "[i * " << p->vectorWidth << " + j];\n"; 453 if (p->isFloat16Parameter()) { 454 mJava->indent() << "args." << p->doubleVariableName 455 << " = Float16Utils.convertFloat16ToDouble(args." 456 << p->variableName << ");\n"; 457 } 458 } 459 } 460 mJava->indent() << "// Ask the CoreMathVerifier to validate.\n"; 461 if (hasFloat) { 462 writeJavaCreateTarget(); 463 } 464 mJava->indent() << "String errorMessage = CoreMathVerifier." 465 << mJavaVerifierVerifyMethodName << "(args"; 466 if (hasFloat) { 467 *mJava << ", target"; 468 } 469 *mJava << ");\n"; 470 mJava->indent() << "boolean valid = errorMessage == null;\n"; 471 } else { 472 mJava->indent() << "// Figure out what the outputs should have been.\n"; 473 if (hasFloat) { 474 writeJavaCreateTarget(); 475 } 476 mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args"; 477 if (hasFloat) { 478 *mJava << ", target"; 479 } 480 *mJava << ");\n"; 481 mJava->indent() << "// Validate the outputs.\n"; 482 mJava->indent() << "boolean valid = true;\n"; 483 for (auto p : mAllInputsAndOutputs) { 484 if (p->isOutParameter) { 485 writeJavaTestAndSetValid(*p, "", "[i * " + p->vectorWidth + " + j]"); 486 } 487 } 488 } 489 490 mJava->indent() << "if (!valid)"; 491 mJava->startBlock(); 492 mJava->indent() << "if (!errorFound)"; 493 mJava->startBlock(); 494 mJava->indent() << "errorFound = true;\n"; 495 496 for (auto p : mAllInputsAndOutputs) { 497 if (p->isOutParameter) { 498 writeJavaAppendOutputToMessage(*p, "", "[i * " + p->vectorWidth + " + j]", 499 verifierValidates); 500 } else { 501 writeJavaAppendInputToMessage(*p, "args." + p->variableName); 502 } 503 } 504 if (verifierValidates) { 505 mJava->indent() << "message.append(errorMessage);\n"; 506 } 507 mJava->indent() << "message.append(\"Errors at\");\n"; 508 mJava->endBlock(); 509 510 mJava->indent() << "message.append(\" [\");\n"; 511 mJava->indent() << "message.append(Integer.toString(i));\n"; 512 mJava->indent() << "message.append(\", \");\n"; 513 mJava->indent() << "message.append(Integer.toString(j));\n"; 514 mJava->indent() << "message.append(\"]\");\n"; 515 516 mJava->endBlock(); 517 mJava->endBlock(); 518 mJava->endBlock(); 519 520 mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n"; 521 mJava->indentPlus() 522 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n"; 523 524 mJava->endBlock(); 525 *mJava << "\n"; 526} 527 528void PermutationWriter::writeJavaVerifyVectorMethod() const { 529 writeJavaVerifyMethodHeader(); 530 mJava->startBlock(); 531 532 for (auto p : mAllInputsAndOutputs) { 533 writeJavaArrayInitialization(*p); 534 } 535 mJava->indent() << "StringBuilder message = new StringBuilder();\n"; 536 mJava->indent() << "boolean errorFound = false;\n"; 537 mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)"; 538 mJava->startBlock(); 539 540 mJava->indent() << mJavaArgumentsNClassName << " args = new " << mJavaArgumentsNClassName 541 << "();\n"; 542 543 mJava->indent() << "// Create the appropriate sized arrays in args\n"; 544 for (auto p : mAllInputsAndOutputs) { 545 if (p->mVectorSize != "1") { 546 string type = p->javaBaseType; 547 if (p->isOutParameter && p->isFloatType) { 548 type = "Target.Floaty"; 549 } 550 mJava->indent() << "args." << p->variableName << " = new " << type << "[" 551 << p->mVectorSize << "];\n"; 552 if (p->isFloat16Parameter() && !p->isOutParameter) { 553 mJava->indent() << "args." << p->variableName << "Double = new double[" 554 << p->mVectorSize << "];\n"; 555 } 556 } 557 } 558 559 mJava->indent() << "// Fill args with the input values\n"; 560 for (auto p : mAllInputsAndOutputs) { 561 if (!p->isOutParameter) { 562 if (p->mVectorSize == "1") { 563 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i]" 564 << ";\n"; 565 // Convert the Float16 parameter to double and store it in the appropriate field in 566 // the Arguments class. 567 if (p->isFloat16Parameter()) { 568 mJava->indent() << "args." << p->doubleVariableName << " = " 569 << "Float16Utils.convertFloat16ToDouble(args." 570 << p->variableName << ");\n"; 571 } 572 } else { 573 mJava->indent() << "for (int j = 0; j < " << p->mVectorSize << " ; j++)"; 574 mJava->startBlock(); 575 mJava->indent() << "args." << p->variableName << "[j] = " 576 << p->javaArrayName << "[i * " << p->vectorWidth << " + j]" 577 << ";\n"; 578 579 // Convert the Float16 parameter to double and store it in the appropriate field in 580 // the Arguments class. 581 if (p->isFloat16Parameter()) { 582 mJava->indent() << "args." << p->doubleVariableName << "[j] = " 583 << "Float16Utils.convertFloat16ToDouble(args." 584 << p->variableName << "[j]);\n"; 585 } 586 mJava->endBlock(); 587 } 588 } 589 } 590 writeJavaCreateTarget(); 591 mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName 592 << "(args, target);\n\n"; 593 594 mJava->indent() << "// Compare the expected outputs to the actual values returned by RS.\n"; 595 mJava->indent() << "boolean valid = true;\n"; 596 for (auto p : mAllInputsAndOutputs) { 597 if (p->isOutParameter) { 598 writeJavaVectorComparison(*p); 599 } 600 } 601 602 mJava->indent() << "if (!valid)"; 603 mJava->startBlock(); 604 mJava->indent() << "if (!errorFound)"; 605 mJava->startBlock(); 606 mJava->indent() << "errorFound = true;\n"; 607 608 for (auto p : mAllInputsAndOutputs) { 609 if (p->isOutParameter) { 610 writeJavaAppendVectorOutputToMessage(*p); 611 } else { 612 writeJavaAppendVectorInputToMessage(*p); 613 } 614 } 615 mJava->indent() << "message.append(\"Errors at\");\n"; 616 mJava->endBlock(); 617 618 mJava->indent() << "message.append(\" [\");\n"; 619 mJava->indent() << "message.append(Integer.toString(i));\n"; 620 mJava->indent() << "message.append(\"]\");\n"; 621 622 mJava->endBlock(); 623 mJava->endBlock(); 624 625 mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n"; 626 mJava->indentPlus() 627 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n"; 628 629 mJava->endBlock(); 630 *mJava << "\n"; 631} 632 633 634void PermutationWriter::writeJavaCreateTarget() const { 635 string name = mPermutation.getName(); 636 637 const char* functionType = "NORMAL"; 638 size_t end = name.find('_'); 639 if (end != string::npos) { 640 if (name.compare(0, end, "native") == 0) { 641 functionType = "NATIVE"; 642 } else if (name.compare(0, end, "half") == 0) { 643 functionType = "HALF"; 644 } else if (name.compare(0, end, "fast") == 0) { 645 functionType = "FAST"; 646 } 647 } 648 649 string floatType = mReturnParam->specType; 650 const char* precisionStr = ""; 651 if (floatType.compare("f16") == 0) { 652 precisionStr = "HALF"; 653 } else if (floatType.compare("f32") == 0) { 654 precisionStr = "FLOAT"; 655 } else if (floatType.compare("f64") == 0) { 656 precisionStr = "DOUBLE"; 657 } else { 658 cerr << "Error. Unreachable. Return type is not floating point\n"; 659 } 660 661 mJava->indent() << "Target target = new Target(Target.FunctionType." << 662 functionType << ", Target.ReturnType." << precisionStr << 663 ", relaxed);\n"; 664} 665 666void PermutationWriter::writeJavaVerifyMethodHeader() const { 667 mJava->indent() << "private void " << mJavaVerifyMethodName << "("; 668 for (auto p : mAllInputsAndOutputs) { 669 *mJava << "Allocation " << p->javaAllocName << ", "; 670 } 671 *mJava << "boolean relaxed)"; 672} 673 674void PermutationWriter::writeJavaArrayInitialization(const ParameterDefinition& p) const { 675 mJava->indent() << p.javaBaseType << "[] " << p.javaArrayName << " = new " << p.javaBaseType 676 << "[INPUTSIZE * " << p.vectorWidth << "];\n"; 677 678 /* For basic types, populate the array with values, to help understand failures. We have had 679 * bugs where the output buffer was all 0. We were not sure if there was a failed copy or 680 * the GPU driver was copying zeroes. 681 */ 682 if (p.typeIndex >= 0) { 683 mJava->indent() << "Arrays.fill(" << p.javaArrayName << ", (" << TYPES[p.typeIndex].javaType 684 << ") 42);\n"; 685 } 686 687 mJava->indent() << p.javaAllocName << ".copyTo(" << p.javaArrayName << ");\n"; 688} 689 690void PermutationWriter::writeJavaTestAndSetValid(const ParameterDefinition& p, 691 const string& argsIndex, 692 const string& actualIndex) const { 693 writeJavaTestOneValue(p, argsIndex, actualIndex); 694 mJava->startBlock(); 695 mJava->indent() << "valid = false;\n"; 696 mJava->endBlock(); 697} 698 699void PermutationWriter::writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex, 700 const string& actualIndex) const { 701 string actualOut; 702 if (p.isFloat16Parameter()) { 703 // For Float16 values, the output needs to be converted to Double. 704 actualOut = "Float16Utils.convertFloat16ToDouble(" + p.javaArrayName + actualIndex + ")"; 705 } else { 706 actualOut = p.javaArrayName + actualIndex; 707 } 708 709 mJava->indent() << "if ("; 710 if (p.isFloatType) { 711 *mJava << "!args." << p.variableName << argsIndex << ".couldBe(" << actualOut; 712 const string s = mPermutation.getPrecisionLimit(); 713 if (!s.empty()) { 714 *mJava << ", " << s; 715 } 716 *mJava << ")"; 717 } else { 718 *mJava << "args." << p.variableName << argsIndex << " != " << p.javaArrayName 719 << actualIndex; 720 } 721 722 if (p.undefinedIfOutIsNan && mReturnParam) { 723 *mJava << " && !args." << mReturnParam->variableName << argsIndex << ".isNaN()"; 724 } 725 *mJava << ")"; 726} 727 728void PermutationWriter::writeJavaVectorComparison(const ParameterDefinition& p) const { 729 if (p.mVectorSize == "1") { 730 writeJavaTestAndSetValid(p, "", "[i]"); 731 } else { 732 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)"; 733 mJava->startBlock(); 734 writeJavaTestAndSetValid(p, "[j]", "[i * " + p.vectorWidth + " + j]"); 735 mJava->endBlock(); 736 } 737} 738 739void PermutationWriter::writeJavaAppendOutputToMessage(const ParameterDefinition& p, 740 const string& argsIndex, 741 const string& actualIndex, 742 bool verifierValidates) const { 743 if (verifierValidates) { 744 mJava->indent() << "message.append(\"Output " << p.variableName << ": \");\n"; 745 mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex 746 << ");\n"; 747 writeJavaAppendNewLineToMessage(); 748 if (p.isFloat16Parameter()) { 749 writeJavaAppendNewLineToMessage(); 750 mJava->indent() << "message.append(\"Output " << p.variableName 751 << " (in double): \");\n"; 752 mJava->indent() << "appendVariableToMessage(message, args." << p.doubleVariableName 753 << ");\n"; 754 writeJavaAppendNewLineToMessage(); 755 } 756 } else { 757 mJava->indent() << "message.append(\"Expected output " << p.variableName << ": \");\n"; 758 mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex 759 << ");\n"; 760 writeJavaAppendNewLineToMessage(); 761 762 mJava->indent() << "message.append(\"Actual output " << p.variableName << ": \");\n"; 763 mJava->indent() << "appendVariableToMessage(message, " << p.javaArrayName << actualIndex 764 << ");\n"; 765 766 if (p.isFloat16Parameter()) { 767 writeJavaAppendNewLineToMessage(); 768 mJava->indent() << "message.append(\"Actual output " << p.variableName 769 << " (in double): \");\n"; 770 mJava->indent() << "appendVariableToMessage(message, Float16Utils.convertFloat16ToDouble(" 771 << p.javaArrayName << actualIndex << "));\n"; 772 } 773 774 writeJavaTestOneValue(p, argsIndex, actualIndex); 775 mJava->startBlock(); 776 mJava->indent() << "message.append(\" FAIL\");\n"; 777 mJava->endBlock(); 778 writeJavaAppendNewLineToMessage(); 779 } 780} 781 782void PermutationWriter::writeJavaAppendInputToMessage(const ParameterDefinition& p, 783 const string& actual) const { 784 mJava->indent() << "message.append(\"Input " << p.variableName << ": \");\n"; 785 mJava->indent() << "appendVariableToMessage(message, " << actual << ");\n"; 786 writeJavaAppendNewLineToMessage(); 787} 788 789void PermutationWriter::writeJavaAppendNewLineToMessage() const { 790 mJava->indent() << "message.append(\"\\n\");\n"; 791} 792 793void PermutationWriter::writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const { 794 if (p.mVectorSize == "1") { 795 writeJavaAppendInputToMessage(p, p.javaArrayName + "[i]"); 796 } else { 797 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)"; 798 mJava->startBlock(); 799 writeJavaAppendInputToMessage(p, p.javaArrayName + "[i * " + p.vectorWidth + " + j]"); 800 mJava->endBlock(); 801 } 802} 803 804void PermutationWriter::writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const { 805 if (p.mVectorSize == "1") { 806 writeJavaAppendOutputToMessage(p, "", "[i]", false); 807 } else { 808 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)"; 809 mJava->startBlock(); 810 writeJavaAppendOutputToMessage(p, "[j]", "[i * " + p.vectorWidth + " + j]", false); 811 mJava->endBlock(); 812 } 813} 814 815void PermutationWriter::writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const { 816 string script = "script"; 817 if (relaxed) { 818 script += "Relaxed"; 819 } 820 821 mJava->indent() << "try"; 822 mJava->startBlock(); 823 824 for (auto p : mAllInputsAndOutputs) { 825 if (p->isOutParameter) { 826 writeJavaOutputAllocationDefinition(*p); 827 } 828 } 829 830 for (auto p : mPermutation.getParams()) { 831 if (p != mFirstInputParam) { 832 mJava->indent() << script << ".set_" << p->rsAllocName << "(" << p->javaAllocName 833 << ");\n"; 834 } 835 } 836 837 mJava->indent() << script << ".forEach_" << mRsKernelName << "("; 838 bool needComma = false; 839 if (mFirstInputParam) { 840 *mJava << mFirstInputParam->javaAllocName; 841 needComma = true; 842 } 843 if (mReturnParam) { 844 if (needComma) { 845 *mJava << ", "; 846 } 847 *mJava << mReturnParam->variableName << ");\n"; 848 } 849 850 if (generateCallToVerifier) { 851 mJava->indent() << mJavaVerifyMethodName << "("; 852 for (auto p : mAllInputsAndOutputs) { 853 *mJava << p->variableName << ", "; 854 } 855 856 if (relaxed) { 857 *mJava << "true"; 858 } else { 859 *mJava << "false"; 860 } 861 *mJava << ");\n"; 862 } 863 mJava->decreaseIndent(); 864 mJava->indent() << "} catch (Exception e) {\n"; 865 mJava->increaseIndent(); 866 mJava->indent() << "throw new RSRuntimeException(\"RenderScript. Can't invoke forEach_" 867 << mRsKernelName << ": \" + e.toString());\n"; 868 mJava->endBlock(); 869} 870 871/* Write the section of the .rs file for this permutation. 872 * 873 * We communicate the extra input and output parameters via global allocations. 874 * For example, if we have a function that takes three arguments, two for input 875 * and one for output: 876 * 877 * start: 878 * name: gamn 879 * ret: float3 880 * arg: float3 a 881 * arg: int b 882 * arg: float3 *c 883 * end: 884 * 885 * We'll produce: 886 * 887 * rs_allocation gAllocInB; 888 * rs_allocation gAllocOutC; 889 * 890 * float3 __attribute__((kernel)) test_gamn_float3_int_float3(float3 inA, unsigned int x) { 891 * int inB; 892 * float3 outC; 893 * float2 out; 894 * inB = rsGetElementAt_int(gAllocInB, x); 895 * out = gamn(a, in_b, &outC); 896 * rsSetElementAt_float4(gAllocOutC, &outC, x); 897 * return out; 898 * } 899 * 900 * We avoid re-using x and y from the definition because these have reserved 901 * meanings in a .rs file. 902 */ 903void PermutationWriter::writeRsSection(set<string>* rsAllocationsGenerated) const { 904 // Write the allocation declarations we'll need. 905 for (auto p : mPermutation.getParams()) { 906 // Don't need allocation for one input and one return value. 907 if (p != mFirstInputParam) { 908 writeRsAllocationDefinition(*p, rsAllocationsGenerated); 909 } 910 } 911 *mRs << "\n"; 912 913 // Write the function header. 914 if (mReturnParam) { 915 *mRs << mReturnParam->rsType; 916 } else { 917 *mRs << "void"; 918 } 919 *mRs << " __attribute__((kernel)) " << mRsKernelName; 920 *mRs << "("; 921 bool needComma = false; 922 if (mFirstInputParam) { 923 *mRs << mFirstInputParam->rsType << " " << mFirstInputParam->variableName; 924 needComma = true; 925 } 926 if (mPermutation.getOutputCount() > 1 || mPermutation.getInputCount() > 1) { 927 if (needComma) { 928 *mRs << ", "; 929 } 930 *mRs << "unsigned int x"; 931 } 932 *mRs << ")"; 933 mRs->startBlock(); 934 935 // Write the local variable declarations and initializations. 936 for (auto p : mPermutation.getParams()) { 937 if (p == mFirstInputParam) { 938 continue; 939 } 940 mRs->indent() << p->rsType << " " << p->variableName; 941 if (p->isOutParameter) { 942 *mRs << " = 0;\n"; 943 } else { 944 *mRs << " = rsGetElementAt_" << p->rsType << "(" << p->rsAllocName << ", x);\n"; 945 } 946 } 947 948 // Write the function call. 949 if (mReturnParam) { 950 if (mPermutation.getOutputCount() > 1) { 951 mRs->indent() << mReturnParam->rsType << " " << mReturnParam->variableName << " = "; 952 } else { 953 mRs->indent() << "return "; 954 } 955 } 956 *mRs << mPermutation.getName() << "("; 957 needComma = false; 958 for (auto p : mPermutation.getParams()) { 959 if (needComma) { 960 *mRs << ", "; 961 } 962 if (p->isOutParameter) { 963 *mRs << "&"; 964 } 965 *mRs << p->variableName; 966 needComma = true; 967 } 968 *mRs << ");\n"; 969 970 if (mPermutation.getOutputCount() > 1) { 971 // Write setting the extra out parameters into the allocations. 972 for (auto p : mPermutation.getParams()) { 973 if (p->isOutParameter) { 974 mRs->indent() << "rsSetElementAt_" << p->rsType << "(" << p->rsAllocName << ", "; 975 // Check if we need to use '&' for this type of argument. 976 char lastChar = p->variableName.back(); 977 if (lastChar >= '0' && lastChar <= '9') { 978 *mRs << "&"; 979 } 980 *mRs << p->variableName << ", x);\n"; 981 } 982 } 983 if (mReturnParam) { 984 mRs->indent() << "return " << mReturnParam->variableName << ";\n"; 985 } 986 } 987 mRs->endBlock(); 988} 989 990void PermutationWriter::writeRsAllocationDefinition(const ParameterDefinition& param, 991 set<string>* rsAllocationsGenerated) const { 992 if (!testAndSet(param.rsAllocName, rsAllocationsGenerated)) { 993 *mRs << "rs_allocation " << param.rsAllocName << ";\n"; 994 } 995} 996 997// Open the mJavaFile and writes the header. 998static bool startJavaFile(GeneratedFile* file, const Function& function, const string& directory, 999 const string& testName, const string& relaxedTestName) { 1000 const string fileName = testName + ".java"; 1001 if (!file->start(directory, fileName)) { 1002 return false; 1003 } 1004 file->writeNotices(); 1005 1006 *file << "package android.renderscript.cts;\n\n"; 1007 1008 *file << "import android.renderscript.Allocation;\n"; 1009 *file << "import android.renderscript.RSRuntimeException;\n"; 1010 *file << "import android.renderscript.Element;\n"; 1011 *file << "import android.renderscript.cts.Target;\n\n"; 1012 *file << "import java.util.Arrays;\n\n"; 1013 1014 *file << "public class " << testName << " extends RSBaseCompute"; 1015 file->startBlock(); // The corresponding endBlock() is in finishJavaFile() 1016 *file << "\n"; 1017 1018 file->indent() << "private ScriptC_" << testName << " script;\n"; 1019 file->indent() << "private ScriptC_" << relaxedTestName << " scriptRelaxed;\n\n"; 1020 1021 file->indent() << "@Override\n"; 1022 file->indent() << "protected void setUp() throws Exception"; 1023 file->startBlock(); 1024 1025 file->indent() << "super.setUp();\n"; 1026 file->indent() << "script = new ScriptC_" << testName << "(mRS);\n"; 1027 file->indent() << "scriptRelaxed = new ScriptC_" << relaxedTestName << "(mRS);\n"; 1028 1029 file->endBlock(); 1030 *file << "\n"; 1031 return true; 1032} 1033 1034// Write the test method that calls all the generated Check methods. 1035static void finishJavaFile(GeneratedFile* file, const Function& function, 1036 const vector<string>& javaCheckMethods) { 1037 file->indent() << "public void test" << function.getCapitalizedName() << "()"; 1038 file->startBlock(); 1039 for (auto m : javaCheckMethods) { 1040 file->indent() << m << "();\n"; 1041 } 1042 file->endBlock(); 1043 1044 file->endBlock(); 1045} 1046 1047// Open the script file and write its header. 1048static bool startRsFile(GeneratedFile* file, const Function& function, const string& directory, 1049 const string& testName) { 1050 string fileName = testName + ".rs"; 1051 if (!file->start(directory, fileName)) { 1052 return false; 1053 } 1054 file->writeNotices(); 1055 1056 *file << "#pragma version(1)\n"; 1057 *file << "#pragma rs java_package_name(android.renderscript.cts)\n\n"; 1058 return true; 1059} 1060 1061// Write the entire *Relaxed.rs test file, as it only depends on the name. 1062static bool writeRelaxedRsFile(const Function& function, const string& directory, 1063 const string& testName, const string& relaxedTestName) { 1064 string name = relaxedTestName + ".rs"; 1065 1066 GeneratedFile file; 1067 if (!file.start(directory, name)) { 1068 return false; 1069 } 1070 file.writeNotices(); 1071 1072 file << "#include \"" << testName << ".rs\"\n"; 1073 file << "#pragma rs_fp_relaxed\n"; 1074 file.close(); 1075 return true; 1076} 1077 1078/* Write the .java and the two .rs test files. versionOfTestFiles is used to restrict which API 1079 * to test. 1080 */ 1081static bool writeTestFilesForFunction(const Function& function, const string& directory, 1082 unsigned int versionOfTestFiles) { 1083 // Avoid creating empty files if we're not testing this function. 1084 if (!needTestFiles(function, versionOfTestFiles)) { 1085 return true; 1086 } 1087 1088 const string testName = "Test" + function.getCapitalizedName(); 1089 const string relaxedTestName = testName + "Relaxed"; 1090 1091 if (!writeRelaxedRsFile(function, directory, testName, relaxedTestName)) { 1092 return false; 1093 } 1094 1095 GeneratedFile rsFile; // The Renderscript test file we're generating. 1096 GeneratedFile javaFile; // The Jave test file we're generating. 1097 if (!startRsFile(&rsFile, function, directory, testName)) { 1098 return false; 1099 } 1100 1101 if (!startJavaFile(&javaFile, function, directory, testName, relaxedTestName)) { 1102 return false; 1103 } 1104 1105 /* We keep track of the allocations generated in the .rs file and the argument classes defined 1106 * in the Java file, as we share these between the functions created for each specification. 1107 */ 1108 set<string> rsAllocationsGenerated; 1109 set<string> javaGeneratedArgumentClasses; 1110 // Lines of Java code to invoke the check methods. 1111 vector<string> javaCheckMethods; 1112 1113 for (auto spec : function.getSpecifications()) { 1114 if (spec->hasTests(versionOfTestFiles)) { 1115 for (auto permutation : spec->getPermutations()) { 1116 PermutationWriter w(*permutation, &rsFile, &javaFile); 1117 w.writeRsSection(&rsAllocationsGenerated); 1118 w.writeJavaSection(&javaGeneratedArgumentClasses); 1119 1120 // Store the check method to be called. 1121 javaCheckMethods.push_back(w.getJavaCheckMethodName()); 1122 } 1123 } 1124 } 1125 1126 finishJavaFile(&javaFile, function, javaCheckMethods); 1127 // There's no work to wrap-up in the .rs file. 1128 1129 rsFile.close(); 1130 javaFile.close(); 1131 return true; 1132} 1133 1134bool generateTestFiles(const string& directory, unsigned int versionOfTestFiles) { 1135 bool success = true; 1136 for (auto f : systemSpecification.getFunctions()) { 1137 if (!writeTestFilesForFunction(*f.second, directory, versionOfTestFiles)) { 1138 success = false; 1139 } 1140 } 1141 return success; 1142} 1143