1/* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include <iomanip> 18#include <iostream> 19#include <cmath> 20#include <sstream> 21 22#include "Generator.h" 23#include "Specification.h" 24#include "Utilities.h" 25 26using namespace std; 27 28// Converts float2 to FLOAT_32 and 2, etc. 29static void convertToRsType(const string& name, string* dataType, char* vectorSize) { 30 string s = name; 31 int last = s.size() - 1; 32 char lastChar = s[last]; 33 if (lastChar >= '1' && lastChar <= '4') { 34 s.erase(last); 35 *vectorSize = lastChar; 36 } else { 37 *vectorSize = '1'; 38 } 39 dataType->clear(); 40 for (int i = 0; i < NUM_TYPES; i++) { 41 if (s == TYPES[i].cType) { 42 *dataType = TYPES[i].rsDataType; 43 break; 44 } 45 } 46} 47 48// Returns true if any permutation of the function have tests to b 49static bool needTestFiles(const Function& function, unsigned int versionOfTestFiles) { 50 for (auto spec : function.getSpecifications()) { 51 if (spec->hasTests(versionOfTestFiles)) { 52 return true; 53 } 54 } 55 return false; 56} 57 58/* One instance of this class is generated for each permutation of a function for which 59 * we are generating test code. This instance will generate both the script and the Java 60 * section of the test files for this permutation. The class is mostly used to keep track 61 * of the various names shared between script and Java files. 62 * WARNING: Because the constructor keeps a reference to the FunctionPermutation, PermutationWriter 63 * should not exceed the lifetime of FunctionPermutation. 64 */ 65class PermutationWriter { 66private: 67 FunctionPermutation& mPermutation; 68 69 string mRsKernelName; 70 string mJavaArgumentsClassName; 71 string mJavaArgumentsNClassName; 72 string mJavaVerifierComputeMethodName; 73 string mJavaVerifierVerifyMethodName; 74 string mJavaCheckMethodName; 75 string mJavaVerifyMethodName; 76 77 // Pointer to the files we are generating. Handy to avoid always passing them in the calls. 78 GeneratedFile* mRs; 79 GeneratedFile* mJava; 80 81 /* Shortcuts to the return parameter and the first input parameter of the function 82 * specification. 83 */ 84 const ParameterDefinition* mReturnParam; // Can be nullptr. NOT OWNED. 85 const ParameterDefinition* mFirstInputParam; // Can be nullptr. NOT OWNED. 86 87 /* All the parameters plus the return param, if present. Collecting them together 88 * simplifies code generation. NOT OWNED. 89 */ 90 vector<const ParameterDefinition*> mAllInputsAndOutputs; 91 92 /* We use a class to pass the arguments between the generated code and the CoreVerifier. This 93 * method generates this class. The set keeps track if we've generated this class already 94 * for this test file, as more than one permutation may use the same argument class. 95 */ 96 void writeJavaArgumentClass(bool scalar, set<string>* javaGeneratedArgumentClasses) const; 97 98 // Generate the Check* method that invokes the script and calls the verifier. 99 void writeJavaCheckMethod(bool generateCallToVerifier) const; 100 101 // Generate code to define and randomly initialize the input allocation. 102 void writeJavaInputAllocationDefinition(const ParameterDefinition& param) const; 103 104 /* Generate code that instantiate an allocation of floats or integers and fills it with 105 * random data. This random data must be compatible with the specified type. This is 106 * used for the convert_* tests, as converting values that don't fit yield undefined results. 107 */ 108 void writeJavaRandomCompatibleFloatAllocation(const string& dataType, const string& seed, 109 char vectorSize, 110 const NumericalType& compatibleType, 111 const NumericalType& generatedType) const; 112 void writeJavaRandomCompatibleIntegerAllocation(const string& dataType, const string& seed, 113 char vectorSize, 114 const NumericalType& compatibleType, 115 const NumericalType& generatedType) const; 116 117 // Generate code that defines an output allocation. 118 void writeJavaOutputAllocationDefinition(const ParameterDefinition& param) const; 119 120 /* Generate the code that verifies the results for RenderScript functions where each entry 121 * of a vector is evaluated independently. If verifierValidates is true, CoreMathVerifier 122 * does the actual validation instead of more commonly returning the range of acceptable values. 123 */ 124 void writeJavaVerifyScalarMethod(bool verifierValidates) const; 125 126 /* Generate the code that verify the results for a RenderScript function where a vector 127 * is a point in n-dimensional space. 128 */ 129 void writeJavaVerifyVectorMethod() const; 130 131 // Generate the line that creates the Target. 132 void writeJavaCreateTarget() const; 133 134 // Generate the method header of the verify function. 135 void writeJavaVerifyMethodHeader() const; 136 137 // Generate codes that copies the content of an allocation to an array. 138 void writeJavaArrayInitialization(const ParameterDefinition& p) const; 139 140 // Generate code that tests one value returned from the script. 141 void writeJavaTestAndSetValid(const ParameterDefinition& p, const string& argsIndex, 142 const string& actualIndex) const; 143 void writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex, 144 const string& actualIndex) const; 145 // For test:vector cases, generate code that compares returned vector vs. expected value. 146 void writeJavaVectorComparison(const ParameterDefinition& p) const; 147 148 // Muliple functions that generates code to build the error message if an error is found. 149 void writeJavaAppendOutputToMessage(const ParameterDefinition& p, const string& argsIndex, 150 const string& actualIndex, bool verifierValidates) const; 151 void writeJavaAppendInputToMessage(const ParameterDefinition& p, const string& actual) const; 152 void writeJavaAppendNewLineToMessage() const; 153 void writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const; 154 void writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const; 155 156 // Generate the set of instructions to call the script. 157 void writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const; 158 159 // Write an allocation definition if not already emitted in the .rs file. 160 void writeRsAllocationDefinition(const ParameterDefinition& param, 161 set<string>* rsAllocationsGenerated) const; 162 163public: 164 /* NOTE: We keep pointers to the permutation and the files. This object should not 165 * outlive the arguments. 166 */ 167 PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile, 168 GeneratedFile* javaFile); 169 string getJavaCheckMethodName() const { return mJavaCheckMethodName; } 170 171 // Write the script test function for this permutation. 172 void writeRsSection(set<string>* rsAllocationsGenerated) const; 173 // Write the section of the Java code that calls the script and validates the results 174 void writeJavaSection(set<string>* javaGeneratedArgumentClasses) const; 175}; 176 177PermutationWriter::PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile, 178 GeneratedFile* javaFile) 179 : mPermutation(permutation), 180 mRs(rsFile), 181 mJava(javaFile), 182 mReturnParam(nullptr), 183 mFirstInputParam(nullptr) { 184 mRsKernelName = "test" + capitalize(permutation.getName()); 185 186 mJavaArgumentsClassName = "Arguments"; 187 mJavaArgumentsNClassName = "Arguments"; 188 const string trunk = capitalize(permutation.getNameTrunk()); 189 mJavaCheckMethodName = "check" + trunk; 190 mJavaVerifyMethodName = "verifyResults" + trunk; 191 192 for (auto p : permutation.getParams()) { 193 mAllInputsAndOutputs.push_back(p); 194 if (mFirstInputParam == nullptr && !p->isOutParameter) { 195 mFirstInputParam = p; 196 } 197 } 198 mReturnParam = permutation.getReturn(); 199 if (mReturnParam) { 200 mAllInputsAndOutputs.push_back(mReturnParam); 201 } 202 203 for (auto p : mAllInputsAndOutputs) { 204 const string capitalizedRsType = capitalize(p->rsType); 205 const string capitalizedBaseType = capitalize(p->rsBaseType); 206 mRsKernelName += capitalizedRsType; 207 mJavaArgumentsClassName += capitalizedBaseType; 208 mJavaArgumentsNClassName += capitalizedBaseType; 209 if (p->mVectorSize != "1") { 210 mJavaArgumentsNClassName += "N"; 211 } 212 mJavaCheckMethodName += capitalizedRsType; 213 mJavaVerifyMethodName += capitalizedRsType; 214 } 215 mJavaVerifierComputeMethodName = "compute" + trunk; 216 mJavaVerifierVerifyMethodName = "verify" + trunk; 217} 218 219void PermutationWriter::writeJavaSection(set<string>* javaGeneratedArgumentClasses) const { 220 // By default, we test the results using item by item comparison. 221 const string test = mPermutation.getTest(); 222 if (test == "scalar" || test == "limited") { 223 writeJavaArgumentClass(true, javaGeneratedArgumentClasses); 224 writeJavaCheckMethod(true); 225 writeJavaVerifyScalarMethod(false); 226 } else if (test == "custom") { 227 writeJavaArgumentClass(true, javaGeneratedArgumentClasses); 228 writeJavaCheckMethod(true); 229 writeJavaVerifyScalarMethod(true); 230 } else if (test == "vector") { 231 writeJavaArgumentClass(false, javaGeneratedArgumentClasses); 232 writeJavaCheckMethod(true); 233 writeJavaVerifyVectorMethod(); 234 } else if (test == "noverify") { 235 writeJavaCheckMethod(false); 236 } 237} 238 239void PermutationWriter::writeJavaArgumentClass(bool scalar, 240 set<string>* javaGeneratedArgumentClasses) const { 241 string name; 242 if (scalar) { 243 name = mJavaArgumentsClassName; 244 } else { 245 name = mJavaArgumentsNClassName; 246 } 247 248 // Make sure we have not generated the argument class already. 249 if (!testAndSet(name, javaGeneratedArgumentClasses)) { 250 mJava->indent() << "public class " << name; 251 mJava->startBlock(); 252 253 for (auto p : mAllInputsAndOutputs) { 254 bool isFieldArray = !scalar && p->mVectorSize != "1"; 255 bool isFloatyField = p->isOutParameter && p->isFloatType && mPermutation.getTest() != "custom"; 256 257 mJava->indent() << "public "; 258 if (isFloatyField) { 259 *mJava << "Target.Floaty"; 260 } else { 261 *mJava << p->javaBaseType; 262 } 263 if (isFieldArray) { 264 *mJava << "[]"; 265 } 266 *mJava << " " << p->variableName << ";\n"; 267 268 // For Float16 parameters, add an extra 'double' field in the class 269 // to hold the Double value converted from the input. 270 if (p->isFloat16Parameter() && !isFloatyField) { 271 mJava->indent() << "public double"; 272 if (isFieldArray) { 273 *mJava << "[]"; 274 } 275 *mJava << " " + p->variableName << "Double;\n"; 276 } 277 } 278 mJava->endBlock(); 279 *mJava << "\n"; 280 } 281} 282 283void PermutationWriter::writeJavaCheckMethod(bool generateCallToVerifier) const { 284 mJava->indent() << "private void " << mJavaCheckMethodName << "()"; 285 mJava->startBlock(); 286 287 // Generate the input allocations and initialization. 288 for (auto p : mAllInputsAndOutputs) { 289 if (!p->isOutParameter) { 290 writeJavaInputAllocationDefinition(*p); 291 } 292 } 293 // Generate code to enforce ordering between two allocations if needed. 294 for (auto p : mAllInputsAndOutputs) { 295 if (!p->isOutParameter && !p->smallerParameter.empty()) { 296 string smallerAlloc = "in" + capitalize(p->smallerParameter); 297 mJava->indent() << "enforceOrdering(" << smallerAlloc << ", " << p->javaAllocName 298 << ");\n"; 299 } 300 } 301 302 // Generate code to check the full and relaxed scripts. 303 writeJavaCallToRs(false, generateCallToVerifier); 304 writeJavaCallToRs(true, generateCallToVerifier); 305 306 // Generate code to destroy input Allocations. 307 for (auto p : mAllInputsAndOutputs) { 308 if (!p->isOutParameter) { 309 mJava->indent() << p->javaAllocName << ".destroy();\n"; 310 } 311 } 312 313 mJava->endBlock(); 314 *mJava << "\n"; 315} 316 317void PermutationWriter::writeJavaInputAllocationDefinition(const ParameterDefinition& param) const { 318 string dataType; 319 char vectorSize; 320 convertToRsType(param.rsType, &dataType, &vectorSize); 321 322 const string seed = hashString(mJavaCheckMethodName + param.javaAllocName); 323 mJava->indent() << "Allocation " << param.javaAllocName << " = "; 324 if (param.compatibleTypeIndex >= 0) { 325 if (TYPES[param.typeIndex].kind == FLOATING_POINT) { 326 writeJavaRandomCompatibleFloatAllocation(dataType, seed, vectorSize, 327 TYPES[param.compatibleTypeIndex], 328 TYPES[param.typeIndex]); 329 } else { 330 writeJavaRandomCompatibleIntegerAllocation(dataType, seed, vectorSize, 331 TYPES[param.compatibleTypeIndex], 332 TYPES[param.typeIndex]); 333 } 334 } else if (!param.minValue.empty()) { 335 *mJava << "createRandomFloatAllocation(mRS, Element.DataType." << dataType << ", " 336 << vectorSize << ", " << seed << ", " << param.minValue << ", " << param.maxValue 337 << ")"; 338 } else { 339 /* TODO Instead of passing always false, check whether we are doing a limited test. 340 * Use instead: (mPermutation.getTest() == "limited" ? "false" : "true") 341 */ 342 *mJava << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize 343 << ", " << seed << ", false)"; 344 } 345 *mJava << ";\n"; 346} 347 348void PermutationWriter::writeJavaRandomCompatibleFloatAllocation( 349 const string& dataType, const string& seed, char vectorSize, 350 const NumericalType& compatibleType, const NumericalType& generatedType) const { 351 *mJava << "createRandomFloatAllocation" 352 << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", "; 353 double minValue = 0.0; 354 double maxValue = 0.0; 355 switch (compatibleType.kind) { 356 case FLOATING_POINT: { 357 // We're generating floating point values. We just worry about the exponent. 358 // Subtract 1 for the exponent sign. 359 int bits = min(compatibleType.exponentBits, generatedType.exponentBits) - 1; 360 maxValue = ldexp(0.95, (1 << bits) - 1); 361 minValue = -maxValue; 362 break; 363 } 364 case UNSIGNED_INTEGER: 365 maxValue = maxDoubleForInteger(compatibleType.significantBits, 366 generatedType.significantBits); 367 minValue = 0.0; 368 break; 369 case SIGNED_INTEGER: 370 maxValue = maxDoubleForInteger(compatibleType.significantBits, 371 generatedType.significantBits); 372 minValue = -maxValue - 1.0; 373 break; 374 } 375 *mJava << scientific << std::setprecision(19); 376 *mJava << minValue << ", " << maxValue << ")"; 377 mJava->unsetf(ios_base::floatfield); 378} 379 380void PermutationWriter::writeJavaRandomCompatibleIntegerAllocation( 381 const string& dataType, const string& seed, char vectorSize, 382 const NumericalType& compatibleType, const NumericalType& generatedType) const { 383 *mJava << "createRandomIntegerAllocation" 384 << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", "; 385 386 if (compatibleType.kind == FLOATING_POINT) { 387 // Currently, all floating points can take any number we generate. 388 bool isSigned = generatedType.kind == SIGNED_INTEGER; 389 *mJava << (isSigned ? "true" : "false") << ", " << generatedType.significantBits; 390 } else { 391 bool isSigned = 392 compatibleType.kind == SIGNED_INTEGER && generatedType.kind == SIGNED_INTEGER; 393 *mJava << (isSigned ? "true" : "false") << ", " 394 << min(compatibleType.significantBits, generatedType.significantBits); 395 } 396 *mJava << ")"; 397} 398 399void PermutationWriter::writeJavaOutputAllocationDefinition( 400 const ParameterDefinition& param) const { 401 string dataType; 402 char vectorSize; 403 convertToRsType(param.rsType, &dataType, &vectorSize); 404 mJava->indent() << "Allocation " << param.javaAllocName << " = Allocation.createSized(mRS, " 405 << "getElement(mRS, Element.DataType." << dataType << ", " << vectorSize 406 << "), INPUTSIZE);\n"; 407} 408 409void PermutationWriter::writeJavaVerifyScalarMethod(bool verifierValidates) const { 410 writeJavaVerifyMethodHeader(); 411 mJava->startBlock(); 412 413 string vectorSize = "1"; 414 for (auto p : mAllInputsAndOutputs) { 415 writeJavaArrayInitialization(*p); 416 if (p->mVectorSize != "1" && p->mVectorSize != vectorSize) { 417 if (vectorSize == "1") { 418 vectorSize = p->mVectorSize; 419 } else { 420 cerr << "Error. Had vector " << vectorSize << " and " << p->mVectorSize << "\n"; 421 } 422 } 423 } 424 425 mJava->indent() << "StringBuilder message = new StringBuilder();\n"; 426 mJava->indent() << "boolean errorFound = false;\n"; 427 mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)"; 428 mJava->startBlock(); 429 430 mJava->indent() << "for (int j = 0; j < " << vectorSize << " ; j++)"; 431 mJava->startBlock(); 432 433 mJava->indent() << "// Extract the inputs.\n"; 434 mJava->indent() << mJavaArgumentsClassName << " args = new " << mJavaArgumentsClassName 435 << "();\n"; 436 for (auto p : mAllInputsAndOutputs) { 437 if (!p->isOutParameter) { 438 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i"; 439 if (p->vectorWidth != "1") { 440 *mJava << " * " << p->vectorWidth << " + j"; 441 } 442 *mJava << "];\n"; 443 444 // Convert the Float16 parameter to double and store it in the appropriate field in the 445 // Arguments class. 446 if (p->isFloat16Parameter()) { 447 mJava->indent() << "args." << p->doubleVariableName 448 << " = Float16Utils.convertFloat16ToDouble(args." 449 << p->variableName << ");\n"; 450 } 451 } 452 } 453 const bool hasFloat = mPermutation.hasFloatAnswers(); 454 if (verifierValidates) { 455 mJava->indent() << "// Extract the outputs.\n"; 456 for (auto p : mAllInputsAndOutputs) { 457 if (p->isOutParameter) { 458 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName 459 << "[i * " << p->vectorWidth << " + j];\n"; 460 if (p->isFloat16Parameter()) { 461 mJava->indent() << "args." << p->doubleVariableName 462 << " = Float16Utils.convertFloat16ToDouble(args." 463 << p->variableName << ");\n"; 464 } 465 } 466 } 467 mJava->indent() << "// Ask the CoreMathVerifier to validate.\n"; 468 if (hasFloat) { 469 writeJavaCreateTarget(); 470 } 471 mJava->indent() << "String errorMessage = CoreMathVerifier." 472 << mJavaVerifierVerifyMethodName << "(args"; 473 if (hasFloat) { 474 *mJava << ", target"; 475 } 476 *mJava << ");\n"; 477 mJava->indent() << "boolean valid = errorMessage == null;\n"; 478 } else { 479 mJava->indent() << "// Figure out what the outputs should have been.\n"; 480 if (hasFloat) { 481 writeJavaCreateTarget(); 482 } 483 mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args"; 484 if (hasFloat) { 485 *mJava << ", target"; 486 } 487 *mJava << ");\n"; 488 mJava->indent() << "// Validate the outputs.\n"; 489 mJava->indent() << "boolean valid = true;\n"; 490 for (auto p : mAllInputsAndOutputs) { 491 if (p->isOutParameter) { 492 writeJavaTestAndSetValid(*p, "", "[i * " + p->vectorWidth + " + j]"); 493 } 494 } 495 } 496 497 mJava->indent() << "if (!valid)"; 498 mJava->startBlock(); 499 mJava->indent() << "if (!errorFound)"; 500 mJava->startBlock(); 501 mJava->indent() << "errorFound = true;\n"; 502 503 for (auto p : mAllInputsAndOutputs) { 504 if (p->isOutParameter) { 505 writeJavaAppendOutputToMessage(*p, "", "[i * " + p->vectorWidth + " + j]", 506 verifierValidates); 507 } else { 508 writeJavaAppendInputToMessage(*p, "args." + p->variableName); 509 } 510 } 511 if (verifierValidates) { 512 mJava->indent() << "message.append(errorMessage);\n"; 513 } 514 mJava->indent() << "message.append(\"Errors at\");\n"; 515 mJava->endBlock(); 516 517 mJava->indent() << "message.append(\" [\");\n"; 518 mJava->indent() << "message.append(Integer.toString(i));\n"; 519 mJava->indent() << "message.append(\", \");\n"; 520 mJava->indent() << "message.append(Integer.toString(j));\n"; 521 mJava->indent() << "message.append(\"]\");\n"; 522 523 mJava->endBlock(); 524 mJava->endBlock(); 525 mJava->endBlock(); 526 527 mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n"; 528 mJava->indentPlus() 529 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n"; 530 531 mJava->endBlock(); 532 *mJava << "\n"; 533} 534 535void PermutationWriter::writeJavaVerifyVectorMethod() const { 536 writeJavaVerifyMethodHeader(); 537 mJava->startBlock(); 538 539 for (auto p : mAllInputsAndOutputs) { 540 writeJavaArrayInitialization(*p); 541 } 542 mJava->indent() << "StringBuilder message = new StringBuilder();\n"; 543 mJava->indent() << "boolean errorFound = false;\n"; 544 mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)"; 545 mJava->startBlock(); 546 547 mJava->indent() << mJavaArgumentsNClassName << " args = new " << mJavaArgumentsNClassName 548 << "();\n"; 549 550 mJava->indent() << "// Create the appropriate sized arrays in args\n"; 551 for (auto p : mAllInputsAndOutputs) { 552 if (p->mVectorSize != "1") { 553 string type = p->javaBaseType; 554 if (p->isOutParameter && p->isFloatType) { 555 type = "Target.Floaty"; 556 } 557 mJava->indent() << "args." << p->variableName << " = new " << type << "[" 558 << p->mVectorSize << "];\n"; 559 if (p->isFloat16Parameter() && !p->isOutParameter) { 560 mJava->indent() << "args." << p->variableName << "Double = new double[" 561 << p->mVectorSize << "];\n"; 562 } 563 } 564 } 565 566 mJava->indent() << "// Fill args with the input values\n"; 567 for (auto p : mAllInputsAndOutputs) { 568 if (!p->isOutParameter) { 569 if (p->mVectorSize == "1") { 570 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i]" 571 << ";\n"; 572 // Convert the Float16 parameter to double and store it in the appropriate field in 573 // the Arguments class. 574 if (p->isFloat16Parameter()) { 575 mJava->indent() << "args." << p->doubleVariableName << " = " 576 << "Float16Utils.convertFloat16ToDouble(args." 577 << p->variableName << ");\n"; 578 } 579 } else { 580 mJava->indent() << "for (int j = 0; j < " << p->mVectorSize << " ; j++)"; 581 mJava->startBlock(); 582 mJava->indent() << "args." << p->variableName << "[j] = " 583 << p->javaArrayName << "[i * " << p->vectorWidth << " + j]" 584 << ";\n"; 585 586 // Convert the Float16 parameter to double and store it in the appropriate field in 587 // the Arguments class. 588 if (p->isFloat16Parameter()) { 589 mJava->indent() << "args." << p->doubleVariableName << "[j] = " 590 << "Float16Utils.convertFloat16ToDouble(args." 591 << p->variableName << "[j]);\n"; 592 } 593 mJava->endBlock(); 594 } 595 } 596 } 597 writeJavaCreateTarget(); 598 mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName 599 << "(args, target);\n\n"; 600 601 mJava->indent() << "// Compare the expected outputs to the actual values returned by RS.\n"; 602 mJava->indent() << "boolean valid = true;\n"; 603 for (auto p : mAllInputsAndOutputs) { 604 if (p->isOutParameter) { 605 writeJavaVectorComparison(*p); 606 } 607 } 608 609 mJava->indent() << "if (!valid)"; 610 mJava->startBlock(); 611 mJava->indent() << "if (!errorFound)"; 612 mJava->startBlock(); 613 mJava->indent() << "errorFound = true;\n"; 614 615 for (auto p : mAllInputsAndOutputs) { 616 if (p->isOutParameter) { 617 writeJavaAppendVectorOutputToMessage(*p); 618 } else { 619 writeJavaAppendVectorInputToMessage(*p); 620 } 621 } 622 mJava->indent() << "message.append(\"Errors at\");\n"; 623 mJava->endBlock(); 624 625 mJava->indent() << "message.append(\" [\");\n"; 626 mJava->indent() << "message.append(Integer.toString(i));\n"; 627 mJava->indent() << "message.append(\"]\");\n"; 628 629 mJava->endBlock(); 630 mJava->endBlock(); 631 632 mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n"; 633 mJava->indentPlus() 634 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n"; 635 636 mJava->endBlock(); 637 *mJava << "\n"; 638} 639 640 641void PermutationWriter::writeJavaCreateTarget() const { 642 string name = mPermutation.getName(); 643 644 const char* functionType = "NORMAL"; 645 size_t end = name.find('_'); 646 if (end != string::npos) { 647 if (name.compare(0, end, "native") == 0) { 648 functionType = "NATIVE"; 649 } else if (name.compare(0, end, "half") == 0) { 650 functionType = "HALF"; 651 } else if (name.compare(0, end, "fast") == 0) { 652 functionType = "FAST"; 653 } 654 } 655 656 string floatType = mReturnParam->specType; 657 const char* precisionStr = ""; 658 if (floatType.compare("f16") == 0) { 659 precisionStr = "HALF"; 660 } else if (floatType.compare("f32") == 0) { 661 precisionStr = "FLOAT"; 662 } else if (floatType.compare("f64") == 0) { 663 precisionStr = "DOUBLE"; 664 } else { 665 cerr << "Error. Unreachable. Return type is not floating point\n"; 666 } 667 668 mJava->indent() << "Target target = new Target(Target.FunctionType." << 669 functionType << ", Target.ReturnType." << precisionStr << 670 ", relaxed);\n"; 671} 672 673void PermutationWriter::writeJavaVerifyMethodHeader() const { 674 mJava->indent() << "private void " << mJavaVerifyMethodName << "("; 675 for (auto p : mAllInputsAndOutputs) { 676 *mJava << "Allocation " << p->javaAllocName << ", "; 677 } 678 *mJava << "boolean relaxed)"; 679} 680 681void PermutationWriter::writeJavaArrayInitialization(const ParameterDefinition& p) const { 682 mJava->indent() << p.javaBaseType << "[] " << p.javaArrayName << " = new " << p.javaBaseType 683 << "[INPUTSIZE * " << p.vectorWidth << "];\n"; 684 685 /* For basic types, populate the array with values, to help understand failures. We have had 686 * bugs where the output buffer was all 0. We were not sure if there was a failed copy or 687 * the GPU driver was copying zeroes. 688 */ 689 if (p.typeIndex >= 0) { 690 mJava->indent() << "Arrays.fill(" << p.javaArrayName << ", (" << TYPES[p.typeIndex].javaType 691 << ") 42);\n"; 692 } 693 694 mJava->indent() << p.javaAllocName << ".copyTo(" << p.javaArrayName << ");\n"; 695} 696 697void PermutationWriter::writeJavaTestAndSetValid(const ParameterDefinition& p, 698 const string& argsIndex, 699 const string& actualIndex) const { 700 writeJavaTestOneValue(p, argsIndex, actualIndex); 701 mJava->startBlock(); 702 mJava->indent() << "valid = false;\n"; 703 mJava->endBlock(); 704} 705 706void PermutationWriter::writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex, 707 const string& actualIndex) const { 708 string actualOut; 709 if (p.isFloat16Parameter()) { 710 // For Float16 values, the output needs to be converted to Double. 711 actualOut = "Float16Utils.convertFloat16ToDouble(" + p.javaArrayName + actualIndex + ")"; 712 } else { 713 actualOut = p.javaArrayName + actualIndex; 714 } 715 716 mJava->indent() << "if ("; 717 if (p.isFloatType) { 718 *mJava << "!args." << p.variableName << argsIndex << ".couldBe(" << actualOut; 719 const string s = mPermutation.getPrecisionLimit(); 720 if (!s.empty()) { 721 *mJava << ", " << s; 722 } 723 *mJava << ")"; 724 } else { 725 *mJava << "args." << p.variableName << argsIndex << " != " << p.javaArrayName 726 << actualIndex; 727 } 728 729 if (p.undefinedIfOutIsNan && mReturnParam) { 730 *mJava << " && !args." << mReturnParam->variableName << argsIndex << ".isNaN()"; 731 } 732 *mJava << ")"; 733} 734 735void PermutationWriter::writeJavaVectorComparison(const ParameterDefinition& p) const { 736 if (p.mVectorSize == "1") { 737 writeJavaTestAndSetValid(p, "", "[i]"); 738 } else { 739 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)"; 740 mJava->startBlock(); 741 writeJavaTestAndSetValid(p, "[j]", "[i * " + p.vectorWidth + " + j]"); 742 mJava->endBlock(); 743 } 744} 745 746void PermutationWriter::writeJavaAppendOutputToMessage(const ParameterDefinition& p, 747 const string& argsIndex, 748 const string& actualIndex, 749 bool verifierValidates) const { 750 if (verifierValidates) { 751 mJava->indent() << "message.append(\"Output " << p.variableName << ": \");\n"; 752 mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex 753 << ");\n"; 754 writeJavaAppendNewLineToMessage(); 755 if (p.isFloat16Parameter()) { 756 writeJavaAppendNewLineToMessage(); 757 mJava->indent() << "message.append(\"Output " << p.variableName 758 << " (in double): \");\n"; 759 mJava->indent() << "appendVariableToMessage(message, args." << p.doubleVariableName 760 << ");\n"; 761 writeJavaAppendNewLineToMessage(); 762 } 763 } else { 764 mJava->indent() << "message.append(\"Expected output " << p.variableName << ": \");\n"; 765 mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex 766 << ");\n"; 767 writeJavaAppendNewLineToMessage(); 768 769 mJava->indent() << "message.append(\"Actual output " << p.variableName << ": \");\n"; 770 mJava->indent() << "appendVariableToMessage(message, " << p.javaArrayName << actualIndex 771 << ");\n"; 772 773 if (p.isFloat16Parameter()) { 774 writeJavaAppendNewLineToMessage(); 775 mJava->indent() << "message.append(\"Actual output " << p.variableName 776 << " (in double): \");\n"; 777 mJava->indent() << "appendVariableToMessage(message, Float16Utils.convertFloat16ToDouble(" 778 << p.javaArrayName << actualIndex << "));\n"; 779 } 780 781 writeJavaTestOneValue(p, argsIndex, actualIndex); 782 mJava->startBlock(); 783 mJava->indent() << "message.append(\" FAIL\");\n"; 784 mJava->endBlock(); 785 writeJavaAppendNewLineToMessage(); 786 } 787} 788 789void PermutationWriter::writeJavaAppendInputToMessage(const ParameterDefinition& p, 790 const string& actual) const { 791 mJava->indent() << "message.append(\"Input " << p.variableName << ": \");\n"; 792 mJava->indent() << "appendVariableToMessage(message, " << actual << ");\n"; 793 writeJavaAppendNewLineToMessage(); 794} 795 796void PermutationWriter::writeJavaAppendNewLineToMessage() const { 797 mJava->indent() << "message.append(\"\\n\");\n"; 798} 799 800void PermutationWriter::writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const { 801 if (p.mVectorSize == "1") { 802 writeJavaAppendInputToMessage(p, p.javaArrayName + "[i]"); 803 } else { 804 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)"; 805 mJava->startBlock(); 806 writeJavaAppendInputToMessage(p, p.javaArrayName + "[i * " + p.vectorWidth + " + j]"); 807 mJava->endBlock(); 808 } 809} 810 811void PermutationWriter::writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const { 812 if (p.mVectorSize == "1") { 813 writeJavaAppendOutputToMessage(p, "", "[i]", false); 814 } else { 815 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)"; 816 mJava->startBlock(); 817 writeJavaAppendOutputToMessage(p, "[j]", "[i * " + p.vectorWidth + " + j]", false); 818 mJava->endBlock(); 819 } 820} 821 822void PermutationWriter::writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const { 823 string script = "script"; 824 if (relaxed) { 825 script += "Relaxed"; 826 } 827 828 mJava->indent() << "try"; 829 mJava->startBlock(); 830 831 for (auto p : mAllInputsAndOutputs) { 832 if (p->isOutParameter) { 833 writeJavaOutputAllocationDefinition(*p); 834 } 835 } 836 837 for (auto p : mPermutation.getParams()) { 838 if (p != mFirstInputParam) { 839 mJava->indent() << script << ".set_" << p->rsAllocName << "(" << p->javaAllocName 840 << ");\n"; 841 } 842 } 843 844 mJava->indent() << script << ".forEach_" << mRsKernelName << "("; 845 bool needComma = false; 846 if (mFirstInputParam) { 847 *mJava << mFirstInputParam->javaAllocName; 848 needComma = true; 849 } 850 if (mReturnParam) { 851 if (needComma) { 852 *mJava << ", "; 853 } 854 *mJava << mReturnParam->variableName << ");\n"; 855 } 856 857 if (generateCallToVerifier) { 858 mJava->indent() << mJavaVerifyMethodName << "("; 859 for (auto p : mAllInputsAndOutputs) { 860 *mJava << p->variableName << ", "; 861 } 862 863 if (relaxed) { 864 *mJava << "true"; 865 } else { 866 *mJava << "false"; 867 } 868 *mJava << ");\n"; 869 } 870 871 // Generate code to destroy output Allocations. 872 for (auto p : mAllInputsAndOutputs) { 873 if (p->isOutParameter) { 874 mJava->indent() << p->javaAllocName << ".destroy();\n"; 875 } 876 } 877 878 mJava->decreaseIndent(); 879 mJava->indent() << "} catch (Exception e) {\n"; 880 mJava->increaseIndent(); 881 mJava->indent() << "throw new RSRuntimeException(\"RenderScript. Can't invoke forEach_" 882 << mRsKernelName << ": \" + e.toString());\n"; 883 mJava->endBlock(); 884} 885 886/* Write the section of the .rs file for this permutation. 887 * 888 * We communicate the extra input and output parameters via global allocations. 889 * For example, if we have a function that takes three arguments, two for input 890 * and one for output: 891 * 892 * start: 893 * name: gamn 894 * ret: float3 895 * arg: float3 a 896 * arg: int b 897 * arg: float3 *c 898 * end: 899 * 900 * We'll produce: 901 * 902 * rs_allocation gAllocInB; 903 * rs_allocation gAllocOutC; 904 * 905 * float3 __attribute__((kernel)) test_gamn_float3_int_float3(float3 inA, unsigned int x) { 906 * int inB; 907 * float3 outC; 908 * float2 out; 909 * inB = rsGetElementAt_int(gAllocInB, x); 910 * out = gamn(a, in_b, &outC); 911 * rsSetElementAt_float4(gAllocOutC, &outC, x); 912 * return out; 913 * } 914 * 915 * We avoid re-using x and y from the definition because these have reserved 916 * meanings in a .rs file. 917 */ 918void PermutationWriter::writeRsSection(set<string>* rsAllocationsGenerated) const { 919 // Write the allocation declarations we'll need. 920 for (auto p : mPermutation.getParams()) { 921 // Don't need allocation for one input and one return value. 922 if (p != mFirstInputParam) { 923 writeRsAllocationDefinition(*p, rsAllocationsGenerated); 924 } 925 } 926 *mRs << "\n"; 927 928 // Write the function header. 929 if (mReturnParam) { 930 *mRs << mReturnParam->rsType; 931 } else { 932 *mRs << "void"; 933 } 934 *mRs << " __attribute__((kernel)) " << mRsKernelName; 935 *mRs << "("; 936 bool needComma = false; 937 if (mFirstInputParam) { 938 *mRs << mFirstInputParam->rsType << " " << mFirstInputParam->variableName; 939 needComma = true; 940 } 941 if (mPermutation.getOutputCount() > 1 || mPermutation.getInputCount() > 1) { 942 if (needComma) { 943 *mRs << ", "; 944 } 945 *mRs << "unsigned int x"; 946 } 947 *mRs << ")"; 948 mRs->startBlock(); 949 950 // Write the local variable declarations and initializations. 951 for (auto p : mPermutation.getParams()) { 952 if (p == mFirstInputParam) { 953 continue; 954 } 955 mRs->indent() << p->rsType << " " << p->variableName; 956 if (p->isOutParameter) { 957 *mRs << " = 0;\n"; 958 } else { 959 *mRs << " = rsGetElementAt_" << p->rsType << "(" << p->rsAllocName << ", x);\n"; 960 } 961 } 962 963 // Write the function call. 964 if (mReturnParam) { 965 if (mPermutation.getOutputCount() > 1) { 966 mRs->indent() << mReturnParam->rsType << " " << mReturnParam->variableName << " = "; 967 } else { 968 mRs->indent() << "return "; 969 } 970 } 971 *mRs << mPermutation.getName() << "("; 972 needComma = false; 973 for (auto p : mPermutation.getParams()) { 974 if (needComma) { 975 *mRs << ", "; 976 } 977 if (p->isOutParameter) { 978 *mRs << "&"; 979 } 980 *mRs << p->variableName; 981 needComma = true; 982 } 983 *mRs << ");\n"; 984 985 if (mPermutation.getOutputCount() > 1) { 986 // Write setting the extra out parameters into the allocations. 987 for (auto p : mPermutation.getParams()) { 988 if (p->isOutParameter) { 989 mRs->indent() << "rsSetElementAt_" << p->rsType << "(" << p->rsAllocName << ", "; 990 // Check if we need to use '&' for this type of argument. 991 char lastChar = p->variableName.back(); 992 if (lastChar >= '0' && lastChar <= '9') { 993 *mRs << "&"; 994 } 995 *mRs << p->variableName << ", x);\n"; 996 } 997 } 998 if (mReturnParam) { 999 mRs->indent() << "return " << mReturnParam->variableName << ";\n"; 1000 } 1001 } 1002 mRs->endBlock(); 1003} 1004 1005void PermutationWriter::writeRsAllocationDefinition(const ParameterDefinition& param, 1006 set<string>* rsAllocationsGenerated) const { 1007 if (!testAndSet(param.rsAllocName, rsAllocationsGenerated)) { 1008 *mRs << "rs_allocation " << param.rsAllocName << ";\n"; 1009 } 1010} 1011 1012// Open the mJavaFile and writes the header. 1013static bool startJavaFile(GeneratedFile* file, const string& directory, 1014 const string& testName, 1015 const string& relaxedTestName) { 1016 const string fileName = testName + ".java"; 1017 if (!file->start(directory, fileName)) { 1018 return false; 1019 } 1020 file->writeNotices(); 1021 1022 *file << "package android.renderscript.cts;\n\n"; 1023 1024 *file << "import android.renderscript.Allocation;\n"; 1025 *file << "import android.renderscript.RSRuntimeException;\n"; 1026 *file << "import android.renderscript.Element;\n"; 1027 *file << "import android.renderscript.cts.Target;\n\n"; 1028 *file << "import java.util.Arrays;\n\n"; 1029 1030 *file << "public class " << testName << " extends RSBaseCompute"; 1031 file->startBlock(); // The corresponding endBlock() is in finishJavaFile() 1032 *file << "\n"; 1033 1034 file->indent() << "private ScriptC_" << testName << " script;\n"; 1035 file->indent() << "private ScriptC_" << relaxedTestName << " scriptRelaxed;\n\n"; 1036 1037 file->indent() << "@Override\n"; 1038 file->indent() << "protected void setUp() throws Exception"; 1039 file->startBlock(); 1040 1041 file->indent() << "super.setUp();\n"; 1042 file->indent() << "script = new ScriptC_" << testName << "(mRS);\n"; 1043 file->indent() << "scriptRelaxed = new ScriptC_" << relaxedTestName << "(mRS);\n"; 1044 1045 file->endBlock(); 1046 *file << "\n"; 1047 1048 file->indent() << "@Override\n"; 1049 file->indent() << "protected void tearDown() throws Exception"; 1050 file->startBlock(); 1051 1052 file->indent() << "script.destroy();\n"; 1053 file->indent() << "scriptRelaxed.destroy();\n"; 1054 file->indent() << "super.tearDown();\n"; 1055 1056 file->endBlock(); 1057 *file << "\n"; 1058 1059 return true; 1060} 1061 1062// Write the test method that calls all the generated Check methods. 1063static void finishJavaFile(GeneratedFile* file, const Function& function, 1064 const vector<string>& javaCheckMethods) { 1065 file->indent() << "public void test" << function.getCapitalizedName() << "()"; 1066 file->startBlock(); 1067 for (auto m : javaCheckMethods) { 1068 file->indent() << m << "();\n"; 1069 } 1070 file->endBlock(); 1071 1072 file->endBlock(); 1073} 1074 1075// Open the script file and write its header. 1076static bool startRsFile(GeneratedFile* file, const string& directory, 1077 const string& testName) { 1078 string fileName = testName + ".rs"; 1079 if (!file->start(directory, fileName)) { 1080 return false; 1081 } 1082 file->writeNotices(); 1083 1084 *file << "#pragma version(1)\n"; 1085 *file << "#pragma rs java_package_name(android.renderscript.cts)\n\n"; 1086 return true; 1087} 1088 1089// Write the entire *Relaxed.rs test file, as it only depends on the name. 1090static bool writeRelaxedRsFile(const string& directory, const string& testName, 1091 const string& relaxedTestName) { 1092 string name = relaxedTestName + ".rs"; 1093 1094 GeneratedFile file; 1095 if (!file.start(directory, name)) { 1096 return false; 1097 } 1098 file.writeNotices(); 1099 1100 file << "#include \"" << testName << ".rs\"\n"; 1101 file << "#pragma rs_fp_relaxed\n"; 1102 file.close(); 1103 return true; 1104} 1105 1106/* Write the .java and the two .rs test files. versionOfTestFiles is used to restrict which API 1107 * to test. 1108 */ 1109static bool writeTestFilesForFunction(const Function& function, const string& directory, 1110 unsigned int versionOfTestFiles) { 1111 // Avoid creating empty files if we're not testing this function. 1112 if (!needTestFiles(function, versionOfTestFiles)) { 1113 return true; 1114 } 1115 1116 const string testName = "Test" + function.getCapitalizedName(); 1117 const string relaxedTestName = testName + "Relaxed"; 1118 1119 if (!writeRelaxedRsFile(directory, testName, relaxedTestName)) { 1120 return false; 1121 } 1122 1123 GeneratedFile rsFile; // The Renderscript test file we're generating. 1124 GeneratedFile javaFile; // The Jave test file we're generating. 1125 if (!startRsFile(&rsFile, directory, testName)) { 1126 return false; 1127 } 1128 1129 if (!startJavaFile(&javaFile, directory, testName, relaxedTestName)) { 1130 return false; 1131 } 1132 1133 /* We keep track of the allocations generated in the .rs file and the argument classes defined 1134 * in the Java file, as we share these between the functions created for each specification. 1135 */ 1136 set<string> rsAllocationsGenerated; 1137 set<string> javaGeneratedArgumentClasses; 1138 // Lines of Java code to invoke the check methods. 1139 vector<string> javaCheckMethods; 1140 1141 for (auto spec : function.getSpecifications()) { 1142 if (spec->hasTests(versionOfTestFiles)) { 1143 for (auto permutation : spec->getPermutations()) { 1144 PermutationWriter w(*permutation, &rsFile, &javaFile); 1145 w.writeRsSection(&rsAllocationsGenerated); 1146 w.writeJavaSection(&javaGeneratedArgumentClasses); 1147 1148 // Store the check method to be called. 1149 javaCheckMethods.push_back(w.getJavaCheckMethodName()); 1150 } 1151 } 1152 } 1153 1154 finishJavaFile(&javaFile, function, javaCheckMethods); 1155 // There's no work to wrap-up in the .rs file. 1156 1157 rsFile.close(); 1158 javaFile.close(); 1159 return true; 1160} 1161 1162bool generateTestFiles(const string& directory, unsigned int versionOfTestFiles) { 1163 bool success = true; 1164 for (auto f : systemSpecification.getFunctions()) { 1165 if (!writeTestFilesForFunction(*f.second, directory, versionOfTestFiles)) { 1166 success = false; 1167 } 1168 } 1169 return success; 1170} 1171