1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <iomanip>
18#include <iostream>
19#include <cmath>
20#include <sstream>
21
22#include "Generator.h"
23#include "Specification.h"
24#include "Utilities.h"
25
26using namespace std;
27
28// Converts float2 to FLOAT_32 and 2, etc.
29static void convertToRsType(const string& name, string* dataType, char* vectorSize) {
30    string s = name;
31    int last = s.size() - 1;
32    char lastChar = s[last];
33    if (lastChar >= '1' && lastChar <= '4') {
34        s.erase(last);
35        *vectorSize = lastChar;
36    } else {
37        *vectorSize = '1';
38    }
39    dataType->clear();
40    for (int i = 0; i < NUM_TYPES; i++) {
41        if (s == TYPES[i].cType) {
42            *dataType = TYPES[i].rsDataType;
43            break;
44        }
45    }
46}
47
48// Returns true if any permutation of the function have tests to b
49static bool needTestFiles(const Function& function, unsigned int versionOfTestFiles) {
50    for (auto spec : function.getSpecifications()) {
51        if (spec->hasTests(versionOfTestFiles)) {
52            return true;
53        }
54    }
55    return false;
56}
57
/* One instance of this class is created for each permutation of a function for which
 * we are generating test code.  This instance will generate both the script and the Java
 * section of the test files for this permutation.  The class is mostly used to keep track
 * of the various names shared between script and Java files.
 * WARNING: Because the constructor keeps a reference to the FunctionPermutation, PermutationWriter
 * should not exceed the lifetime of FunctionPermutation.
 */
class PermutationWriter {
private:
    FunctionPermutation& mPermutation;

    // Generated identifiers shared between the .rs and .java outputs; built by the constructor.
    string mRsKernelName;
    string mJavaArgumentsClassName;
    string mJavaArgumentsNClassName;
    string mJavaVerifierComputeMethodName;
    string mJavaVerifierVerifyMethodName;
    string mJavaCheckMethodName;
    string mJavaVerifyMethodName;

    // Pointer to the files we are generating.  Handy to avoid always passing them in the calls.
    GeneratedFile* mRs;
    GeneratedFile* mJava;

    /* Shortcuts to the return parameter and the first input parameter of the function
     * specification.
     */
    const ParameterDefinition* mReturnParam;      // Can be nullptr.  NOT OWNED.
    const ParameterDefinition* mFirstInputParam;  // Can be nullptr.  NOT OWNED.

    /* All the parameters plus the return param, if present.  Collecting them together
     * simplifies code generation.  NOT OWNED.
     */
    vector<const ParameterDefinition*> mAllInputsAndOutputs;

    /* We use a class to pass the arguments between the generated code and the CoreVerifier.  This
     * method generates this class.  The set keeps track of whether we've generated this class
     * already for this test file, as more than one permutation may use the same argument class.
     */
    void writeJavaArgumentClass(bool scalar, set<string>* javaGeneratedArgumentClasses) const;

    // Generate the Check* method that invokes the script and calls the verifier.
    void writeJavaCheckMethod(bool generateCallToVerifier) const;

    // Generate code to define and randomly initialize the input allocation.
    void writeJavaInputAllocationDefinition(const ParameterDefinition& param) const;

    /* Generate code that instantiates an allocation of floats or integers and fills it with
     * random data. This random data must be compatible with the specified type.  This is
     * used for the convert_* tests, as converting values that don't fit yields undefined results.
     */
    void writeJavaRandomCompatibleFloatAllocation(const string& dataType, const string& seed,
                                                  char vectorSize,
                                                  const NumericalType& compatibleType,
                                                  const NumericalType& generatedType) const;
    void writeJavaRandomCompatibleIntegerAllocation(const string& dataType, const string& seed,
                                                    char vectorSize,
                                                    const NumericalType& compatibleType,
                                                    const NumericalType& generatedType) const;

    // Generate code that defines an output allocation.
    void writeJavaOutputAllocationDefinition(const ParameterDefinition& param) const;

    /* Generate the code that verifies the results for RenderScript functions where each entry
     * of a vector is evaluated independently.  If verifierValidates is true, CoreMathVerifier
     * does the actual validation instead of more commonly returning the range of acceptable values.
     */
    void writeJavaVerifyScalarMethod(bool verifierValidates) const;

    /* Generate the code that verifies the results for a RenderScript function where a vector
     * is a point in n-dimensional space.
     */
    void writeJavaVerifyVectorMethod() const;

    // Generate the line that creates the Target.
    void writeJavaCreateTarget() const;

    // Generate the method header of the verify function.
    void writeJavaVerifyMethodHeader() const;

    // Generate code that copies the content of an allocation to an array.
    void writeJavaArrayInitialization(const ParameterDefinition& p) const;

    // Generate code that tests one value returned from the script.  writeJavaTestAndSetValid
    // additionally emits the block that clears "valid" when the test fails.
    void writeJavaTestAndSetValid(const ParameterDefinition& p, const string& argsIndex,
                                  const string& actualIndex) const;
    void writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
                               const string& actualIndex) const;
    // For test:vector cases, generate code that compares returned vector vs. expected value.
    void writeJavaVectorComparison(const ParameterDefinition& p) const;

    // Multiple functions that generate code to build the error message if an error is found.
    void writeJavaAppendOutputToMessage(const ParameterDefinition& p, const string& argsIndex,
                                        const string& actualIndex, bool verifierValidates) const;
    void writeJavaAppendInputToMessage(const ParameterDefinition& p, const string& actual) const;
    void writeJavaAppendNewLineToMessage() const;
    void writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const;
    void writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const;

    // Generate the set of instructions to call the script.
    void writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const;

    // Write an allocation definition if not already emitted in the .rs file.
    void writeRsAllocationDefinition(const ParameterDefinition& param,
                                     set<string>* rsAllocationsGenerated) const;

public:
    /* NOTE: We keep a reference to the permutation and pointers to the files.  This object
     * should not outlive the arguments.
     */
    PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
                      GeneratedFile* javaFile);
    // Name of the generated Java Check* method for this permutation.
    string getJavaCheckMethodName() const { return mJavaCheckMethodName; }

    // Write the script test function for this permutation.
    void writeRsSection(set<string>* rsAllocationsGenerated) const;
    // Write the section of the Java code that calls the script and validates the results
    void writeJavaSection(set<string>* javaGeneratedArgumentClasses) const;
};
176
177PermutationWriter::PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
178                                     GeneratedFile* javaFile)
179    : mPermutation(permutation),
180      mRs(rsFile),
181      mJava(javaFile),
182      mReturnParam(nullptr),
183      mFirstInputParam(nullptr) {
184    mRsKernelName = "test" + capitalize(permutation.getName());
185
186    mJavaArgumentsClassName = "Arguments";
187    mJavaArgumentsNClassName = "Arguments";
188    const string trunk = capitalize(permutation.getNameTrunk());
189    mJavaCheckMethodName = "check" + trunk;
190    mJavaVerifyMethodName = "verifyResults" + trunk;
191
192    for (auto p : permutation.getParams()) {
193        mAllInputsAndOutputs.push_back(p);
194        if (mFirstInputParam == nullptr && !p->isOutParameter) {
195            mFirstInputParam = p;
196        }
197    }
198    mReturnParam = permutation.getReturn();
199    if (mReturnParam) {
200        mAllInputsAndOutputs.push_back(mReturnParam);
201    }
202
203    for (auto p : mAllInputsAndOutputs) {
204        const string capitalizedRsType = capitalize(p->rsType);
205        const string capitalizedBaseType = capitalize(p->rsBaseType);
206        mRsKernelName += capitalizedRsType;
207        mJavaArgumentsClassName += capitalizedBaseType;
208        mJavaArgumentsNClassName += capitalizedBaseType;
209        if (p->mVectorSize != "1") {
210            mJavaArgumentsNClassName += "N";
211        }
212        mJavaCheckMethodName += capitalizedRsType;
213        mJavaVerifyMethodName += capitalizedRsType;
214    }
215    mJavaVerifierComputeMethodName = "compute" + trunk;
216    mJavaVerifierVerifyMethodName = "verify" + trunk;
217}
218
219void PermutationWriter::writeJavaSection(set<string>* javaGeneratedArgumentClasses) const {
220    // By default, we test the results using item by item comparison.
221    const string test = mPermutation.getTest();
222    if (test == "scalar" || test == "limited") {
223        writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
224        writeJavaCheckMethod(true);
225        writeJavaVerifyScalarMethod(false);
226    } else if (test == "custom") {
227        writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
228        writeJavaCheckMethod(true);
229        writeJavaVerifyScalarMethod(true);
230    } else if (test == "vector") {
231        writeJavaArgumentClass(false, javaGeneratedArgumentClasses);
232        writeJavaCheckMethod(true);
233        writeJavaVerifyVectorMethod();
234    } else if (test == "noverify") {
235        writeJavaCheckMethod(false);
236    }
237}
238
239void PermutationWriter::writeJavaArgumentClass(bool scalar,
240                                               set<string>* javaGeneratedArgumentClasses) const {
241    string name;
242    if (scalar) {
243        name = mJavaArgumentsClassName;
244    } else {
245        name = mJavaArgumentsNClassName;
246    }
247
248    // Make sure we have not generated the argument class already.
249    if (!testAndSet(name, javaGeneratedArgumentClasses)) {
250        mJava->indent() << "public class " << name;
251        mJava->startBlock();
252
253        for (auto p : mAllInputsAndOutputs) {
254            bool isFieldArray = !scalar && p->mVectorSize != "1";
255            bool isFloatyField = p->isOutParameter && p->isFloatType && mPermutation.getTest() != "custom";
256
257            mJava->indent() << "public ";
258            if (isFloatyField) {
259                *mJava << "Target.Floaty";
260            } else {
261                *mJava << p->javaBaseType;
262            }
263            if (isFieldArray) {
264                *mJava << "[]";
265            }
266            *mJava << " " << p->variableName << ";\n";
267
268            // For Float16 parameters, add an extra 'double' field in the class
269            // to hold the Double value converted from the input.
270            if (p->isFloat16Parameter() && !isFloatyField) {
271                mJava->indent() << "public double";
272                if (isFieldArray) {
273                    *mJava << "[]";
274                }
275                *mJava << " " + p->variableName << "Double;\n";
276            }
277        }
278        mJava->endBlock();
279        *mJava << "\n";
280    }
281}
282
283void PermutationWriter::writeJavaCheckMethod(bool generateCallToVerifier) const {
284    mJava->indent() << "private void " << mJavaCheckMethodName << "()";
285    mJava->startBlock();
286
287    // Generate the input allocations and initialization.
288    for (auto p : mAllInputsAndOutputs) {
289        if (!p->isOutParameter) {
290            writeJavaInputAllocationDefinition(*p);
291        }
292    }
293    // Generate code to enforce ordering between two allocations if needed.
294    for (auto p : mAllInputsAndOutputs) {
295        if (!p->isOutParameter && !p->smallerParameter.empty()) {
296            string smallerAlloc = "in" + capitalize(p->smallerParameter);
297            mJava->indent() << "enforceOrdering(" << smallerAlloc << ", " << p->javaAllocName
298                            << ");\n";
299        }
300    }
301
302    // Generate code to check the full and relaxed scripts.
303    writeJavaCallToRs(false, generateCallToVerifier);
304    writeJavaCallToRs(true, generateCallToVerifier);
305
306    // Generate code to destroy input Allocations.
307    for (auto p : mAllInputsAndOutputs) {
308        if (!p->isOutParameter) {
309            mJava->indent() << p->javaAllocName << ".destroy();\n";
310        }
311    }
312
313    mJava->endBlock();
314    *mJava << "\n";
315}
316
317void PermutationWriter::writeJavaInputAllocationDefinition(const ParameterDefinition& param) const {
318    string dataType;
319    char vectorSize;
320    convertToRsType(param.rsType, &dataType, &vectorSize);
321
322    const string seed = hashString(mJavaCheckMethodName + param.javaAllocName);
323    mJava->indent() << "Allocation " << param.javaAllocName << " = ";
324    if (param.compatibleTypeIndex >= 0) {
325        if (TYPES[param.typeIndex].kind == FLOATING_POINT) {
326            writeJavaRandomCompatibleFloatAllocation(dataType, seed, vectorSize,
327                                                     TYPES[param.compatibleTypeIndex],
328                                                     TYPES[param.typeIndex]);
329        } else {
330            writeJavaRandomCompatibleIntegerAllocation(dataType, seed, vectorSize,
331                                                       TYPES[param.compatibleTypeIndex],
332                                                       TYPES[param.typeIndex]);
333        }
334    } else if (!param.minValue.empty()) {
335        *mJava << "createRandomFloatAllocation(mRS, Element.DataType." << dataType << ", "
336               << vectorSize << ", " << seed << ", " << param.minValue << ", " << param.maxValue
337               << ")";
338    } else {
339        /* TODO Instead of passing always false, check whether we are doing a limited test.
340         * Use instead: (mPermutation.getTest() == "limited" ? "false" : "true")
341         */
342        *mJava << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize
343               << ", " << seed << ", false)";
344    }
345    *mJava << ";\n";
346}
347
348void PermutationWriter::writeJavaRandomCompatibleFloatAllocation(
349            const string& dataType, const string& seed, char vectorSize,
350            const NumericalType& compatibleType, const NumericalType& generatedType) const {
351    *mJava << "createRandomFloatAllocation"
352           << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
353    double minValue = 0.0;
354    double maxValue = 0.0;
355    switch (compatibleType.kind) {
356        case FLOATING_POINT: {
357            // We're generating floating point values.  We just worry about the exponent.
358            // Subtract 1 for the exponent sign.
359            int bits = min(compatibleType.exponentBits, generatedType.exponentBits) - 1;
360            maxValue = ldexp(0.95, (1 << bits) - 1);
361            minValue = -maxValue;
362            break;
363        }
364        case UNSIGNED_INTEGER:
365            maxValue = maxDoubleForInteger(compatibleType.significantBits,
366                                           generatedType.significantBits);
367            minValue = 0.0;
368            break;
369        case SIGNED_INTEGER:
370            maxValue = maxDoubleForInteger(compatibleType.significantBits,
371                                           generatedType.significantBits);
372            minValue = -maxValue - 1.0;
373            break;
374    }
375    *mJava << scientific << std::setprecision(19);
376    *mJava << minValue << ", " << maxValue << ")";
377    mJava->unsetf(ios_base::floatfield);
378}
379
380void PermutationWriter::writeJavaRandomCompatibleIntegerAllocation(
381            const string& dataType, const string& seed, char vectorSize,
382            const NumericalType& compatibleType, const NumericalType& generatedType) const {
383    *mJava << "createRandomIntegerAllocation"
384           << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
385
386    if (compatibleType.kind == FLOATING_POINT) {
387        // Currently, all floating points can take any number we generate.
388        bool isSigned = generatedType.kind == SIGNED_INTEGER;
389        *mJava << (isSigned ? "true" : "false") << ", " << generatedType.significantBits;
390    } else {
391        bool isSigned =
392                    compatibleType.kind == SIGNED_INTEGER && generatedType.kind == SIGNED_INTEGER;
393        *mJava << (isSigned ? "true" : "false") << ", "
394               << min(compatibleType.significantBits, generatedType.significantBits);
395    }
396    *mJava << ")";
397}
398
399void PermutationWriter::writeJavaOutputAllocationDefinition(
400            const ParameterDefinition& param) const {
401    string dataType;
402    char vectorSize;
403    convertToRsType(param.rsType, &dataType, &vectorSize);
404    mJava->indent() << "Allocation " << param.javaAllocName << " = Allocation.createSized(mRS, "
405                    << "getElement(mRS, Element.DataType." << dataType << ", " << vectorSize
406                    << "), INPUTSIZE);\n";
407}
408
/* Emits the Java verifyResults* method for tests where every vector entry is checked on its
 * own.  The generated code copies every allocation into an array.  It then loops over
 * INPUTSIZE elements and over each vector entry (j), and fills an Arguments object with the
 * inputs.  Then one of two things happens:
 *   - verifierValidates == false: CoreMathVerifier.compute* fills in the expected outputs, and
 *     each actual output is compared against them;
 *   - verifierValidates == true: the actual outputs are also stored in args, and
 *     CoreMathVerifier.verify* returns an error message (null when valid).
 * Full details of the first failure are appended to "message".  Every failure appends its
 * [i, j] position.  The method ends with assertFalse on errorFound.
 */
void PermutationWriter::writeJavaVerifyScalarMethod(bool verifierValidates) const {
    writeJavaVerifyMethodHeader();
    mJava->startBlock();

    // Copy each allocation into a Java array and find the (single) vector size used by this
    // permutation.  A second, different non-1 size is reported but otherwise ignored.
    string vectorSize = "1";
    for (auto p : mAllInputsAndOutputs) {
        writeJavaArrayInitialization(*p);
        if (p->mVectorSize != "1" && p->mVectorSize != vectorSize) {
            if (vectorSize == "1") {
                vectorSize = p->mVectorSize;
            } else {
                cerr << "Error.  Had vector " << vectorSize << " and " << p->mVectorSize << "\n";
            }
        }
    }

    mJava->indent() << "StringBuilder message = new StringBuilder();\n";
    mJava->indent() << "boolean errorFound = false;\n";
    mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
    mJava->startBlock();

    mJava->indent() << "for (int j = 0; j < " << vectorSize << " ; j++)";
    mJava->startBlock();

    // Inputs of width 1 are read as [i]; vector inputs as [i * width + j].
    mJava->indent() << "// Extract the inputs.\n";
    mJava->indent() << mJavaArgumentsClassName << " args = new " << mJavaArgumentsClassName
                    << "();\n";
    for (auto p : mAllInputsAndOutputs) {
        if (!p->isOutParameter) {
            mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i";
            if (p->vectorWidth != "1") {
                *mJava << " * " << p->vectorWidth << " + j";
            }
            *mJava << "];\n";

            // Convert the Float16 parameter to double and store it in the appropriate field in the
            // Arguments class.
            if (p->isFloat16Parameter()) {
                mJava->indent() << "args." << p->doubleVariableName
                                << " = Float16Utils.convertFloat16ToDouble(args."
                                << p->variableName << ");\n";
            }
        }
    }
    // A Target is only created (and passed to the verifier) when there are float answers.
    const bool hasFloat = mPermutation.hasFloatAnswers();
    if (verifierValidates) {
        mJava->indent() << "// Extract the outputs.\n";
        for (auto p : mAllInputsAndOutputs) {
            if (p->isOutParameter) {
                mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName
                                << "[i * " << p->vectorWidth << " + j];\n";
                if (p->isFloat16Parameter()) {
                    mJava->indent() << "args." << p->doubleVariableName
                                    << " = Float16Utils.convertFloat16ToDouble(args."
                                    << p->variableName << ");\n";
                }
            }
        }
        mJava->indent() << "// Ask the CoreMathVerifier to validate.\n";
        if (hasFloat) {
            writeJavaCreateTarget();
        }
        mJava->indent() << "String errorMessage = CoreMathVerifier."
                        << mJavaVerifierVerifyMethodName << "(args";
        if (hasFloat) {
            *mJava << ", target";
        }
        *mJava << ");\n";
        mJava->indent() << "boolean valid = errorMessage == null;\n";
    } else {
        mJava->indent() << "// Figure out what the outputs should have been.\n";
        if (hasFloat) {
            writeJavaCreateTarget();
        }
        mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args";
        if (hasFloat) {
            *mJava << ", target";
        }
        *mJava << ");\n";
        mJava->indent() << "// Validate the outputs.\n";
        mJava->indent() << "boolean valid = true;\n";
        for (auto p : mAllInputsAndOutputs) {
            if (p->isOutParameter) {
                writeJavaTestAndSetValid(*p, "", "[i * " + p->vectorWidth + " + j]");
            }
        }
    }

    // On failure: describe inputs/outputs only for the first error, then record the position.
    mJava->indent() << "if (!valid)";
    mJava->startBlock();
    mJava->indent() << "if (!errorFound)";
    mJava->startBlock();
    mJava->indent() << "errorFound = true;\n";

    for (auto p : mAllInputsAndOutputs) {
        if (p->isOutParameter) {
            writeJavaAppendOutputToMessage(*p, "", "[i * " + p->vectorWidth + " + j]",
                                           verifierValidates);
        } else {
            writeJavaAppendInputToMessage(*p, "args." + p->variableName);
        }
    }
    if (verifierValidates) {
        mJava->indent() << "message.append(errorMessage);\n";
    }
    mJava->indent() << "message.append(\"Errors at\");\n";
    mJava->endBlock();

    mJava->indent() << "message.append(\" [\");\n";
    mJava->indent() << "message.append(Integer.toString(i));\n";
    mJava->indent() << "message.append(\", \");\n";
    mJava->indent() << "message.append(Integer.toString(j));\n";
    mJava->indent() << "message.append(\"]\");\n";

    // Close, in order: the "if (!valid)" block, the j loop and the i loop.
    mJava->endBlock();
    mJava->endBlock();
    mJava->endBlock();

    mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
    mJava->indentPlus()
                << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";

    mJava->endBlock();
    *mJava << "\n";
}
534
/* Emits the Java verifyResults* method for tests where a whole vector is checked at once.  For
 * each of the INPUTSIZE elements, the generated code fills an ArgumentsN object with the inputs.
 * Vector parameters become arrays sized by mVectorSize.  The code then calls
 * CoreMathVerifier.compute*(args, target) and compares every output entry.  Full details of the
 * first failure are appended to "message".  Every failure appends its [i] position.  The method
 * ends with assertFalse on errorFound.
 */
void PermutationWriter::writeJavaVerifyVectorMethod() const {
    writeJavaVerifyMethodHeader();
    mJava->startBlock();

    for (auto p : mAllInputsAndOutputs) {
        writeJavaArrayInitialization(*p);
    }
    mJava->indent() << "StringBuilder message = new StringBuilder();\n";
    mJava->indent() << "boolean errorFound = false;\n";
    mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
    mJava->startBlock();

    mJava->indent() << mJavaArgumentsNClassName << " args = new " << mJavaArgumentsNClassName
                    << "();\n";

    // Vector fields of the ArgumentsN class are arrays and must be allocated per element.
    // Floating-point outputs use Target.Floaty; Float16 inputs also get a double array.
    mJava->indent() << "// Create the appropriate sized arrays in args\n";
    for (auto p : mAllInputsAndOutputs) {
        if (p->mVectorSize != "1") {
            string type = p->javaBaseType;
            if (p->isOutParameter && p->isFloatType) {
                type = "Target.Floaty";
            }
            mJava->indent() << "args." << p->variableName << " = new " << type << "["
                            << p->mVectorSize << "];\n";
            if (p->isFloat16Parameter() && !p->isOutParameter) {
                mJava->indent() << "args." << p->variableName << "Double = new double["
                                << p->mVectorSize << "];\n";
            }
        }
    }

    mJava->indent() << "// Fill args with the input values\n";
    for (auto p : mAllInputsAndOutputs) {
        if (!p->isOutParameter) {
            if (p->mVectorSize == "1") {
                mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i]"
                                << ";\n";
                // Convert the Float16 parameter to double and store it in the appropriate field in
                // the Arguments class.
                if (p->isFloat16Parameter()) {
                    mJava->indent() << "args." << p->doubleVariableName << " = "
                                    << "Float16Utils.convertFloat16ToDouble(args."
                                    << p->variableName << ");\n";
                }
            } else {
                mJava->indent() << "for (int j = 0; j < " << p->mVectorSize << " ; j++)";
                mJava->startBlock();
                mJava->indent() << "args." << p->variableName << "[j] = "
                                << p->javaArrayName << "[i * " << p->vectorWidth << " + j]"
                                << ";\n";

                // Convert the Float16 parameter to double and store it in the appropriate field in
                // the Arguments class.
                if (p->isFloat16Parameter()) {
                    mJava->indent() << "args." << p->doubleVariableName << "[j] = "
                                    << "Float16Utils.convertFloat16ToDouble(args."
                                    << p->variableName << "[j]);\n";
                }
                mJava->endBlock();
            }
        }
    }
    // Unlike the scalar verifier, a Target is always created here.
    writeJavaCreateTarget();
    mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName
                    << "(args, target);\n\n";

    mJava->indent() << "// Compare the expected outputs to the actual values returned by RS.\n";
    mJava->indent() << "boolean valid = true;\n";
    for (auto p : mAllInputsAndOutputs) {
        if (p->isOutParameter) {
            writeJavaVectorComparison(*p);
        }
    }

    // On failure: describe inputs/outputs only for the first error, then record the position.
    mJava->indent() << "if (!valid)";
    mJava->startBlock();
    mJava->indent() << "if (!errorFound)";
    mJava->startBlock();
    mJava->indent() << "errorFound = true;\n";

    for (auto p : mAllInputsAndOutputs) {
        if (p->isOutParameter) {
            writeJavaAppendVectorOutputToMessage(*p);
        } else {
            writeJavaAppendVectorInputToMessage(*p);
        }
    }
    mJava->indent() << "message.append(\"Errors at\");\n";
    mJava->endBlock();

    mJava->indent() << "message.append(\" [\");\n";
    mJava->indent() << "message.append(Integer.toString(i));\n";
    mJava->indent() << "message.append(\"]\");\n";

    // Close the "if (!valid)" block and the i loop.
    mJava->endBlock();
    mJava->endBlock();

    mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
    mJava->indentPlus()
                << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";

    mJava->endBlock();
    *mJava << "\n";
}
639
640
641void PermutationWriter::writeJavaCreateTarget() const {
642    string name = mPermutation.getName();
643
644    const char* functionType = "NORMAL";
645    size_t end = name.find('_');
646    if (end != string::npos) {
647        if (name.compare(0, end, "native") == 0) {
648            functionType = "NATIVE";
649        } else if (name.compare(0, end, "half") == 0) {
650            functionType = "HALF";
651        } else if (name.compare(0, end, "fast") == 0) {
652            functionType = "FAST";
653        }
654    }
655
656    string floatType = mReturnParam->specType;
657    const char* precisionStr = "";
658    if (floatType.compare("f16") == 0) {
659        precisionStr = "HALF";
660    } else if (floatType.compare("f32") == 0) {
661        precisionStr = "FLOAT";
662    } else if (floatType.compare("f64") == 0) {
663        precisionStr = "DOUBLE";
664    } else {
665        cerr << "Error. Unreachable.  Return type is not floating point\n";
666    }
667
668    mJava->indent() << "Target target = new Target(Target.FunctionType." <<
669                    functionType << ", Target.ReturnType." << precisionStr <<
670                    ", relaxed);\n";
671}
672
673void PermutationWriter::writeJavaVerifyMethodHeader() const {
674    mJava->indent() << "private void " << mJavaVerifyMethodName << "(";
675    for (auto p : mAllInputsAndOutputs) {
676        *mJava << "Allocation " << p->javaAllocName << ", ";
677    }
678    *mJava << "boolean relaxed)";
679}
680
681void PermutationWriter::writeJavaArrayInitialization(const ParameterDefinition& p) const {
682    mJava->indent() << p.javaBaseType << "[] " << p.javaArrayName << " = new " << p.javaBaseType
683                    << "[INPUTSIZE * " << p.vectorWidth << "];\n";
684
685    /* For basic types, populate the array with values, to help understand failures.  We have had
686     * bugs where the output buffer was all 0.  We were not sure if there was a failed copy or
687     * the GPU driver was copying zeroes.
688     */
689    if (p.typeIndex >= 0) {
690        mJava->indent() << "Arrays.fill(" << p.javaArrayName << ", (" << TYPES[p.typeIndex].javaType
691                        << ") 42);\n";
692    }
693
694    mJava->indent() << p.javaAllocName << ".copyTo(" << p.javaArrayName << ");\n";
695}
696
697void PermutationWriter::writeJavaTestAndSetValid(const ParameterDefinition& p,
698                                                 const string& argsIndex,
699                                                 const string& actualIndex) const {
700    writeJavaTestOneValue(p, argsIndex, actualIndex);
701    mJava->startBlock();
702    mJava->indent() << "valid = false;\n";
703    mJava->endBlock();
704}
705
706void PermutationWriter::writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
707                                              const string& actualIndex) const {
708    string actualOut;
709    if (p.isFloat16Parameter()) {
710        // For Float16 values, the output needs to be converted to Double.
711        actualOut = "Float16Utils.convertFloat16ToDouble(" + p.javaArrayName + actualIndex + ")";
712    } else {
713        actualOut = p.javaArrayName + actualIndex;
714    }
715
716    mJava->indent() << "if (";
717    if (p.isFloatType) {
718        *mJava << "!args." << p.variableName << argsIndex << ".couldBe(" << actualOut;
719        const string s = mPermutation.getPrecisionLimit();
720        if (!s.empty()) {
721            *mJava << ", " << s;
722        }
723        *mJava << ")";
724    } else {
725        *mJava << "args." << p.variableName << argsIndex << " != " << p.javaArrayName
726               << actualIndex;
727    }
728
729    if (p.undefinedIfOutIsNan && mReturnParam) {
730        *mJava << " && !args." << mReturnParam->variableName << argsIndex << ".isNaN()";
731    }
732    *mJava << ")";
733}
734
735void PermutationWriter::writeJavaVectorComparison(const ParameterDefinition& p) const {
736    if (p.mVectorSize == "1") {
737        writeJavaTestAndSetValid(p, "", "[i]");
738    } else {
739        mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
740        mJava->startBlock();
741        writeJavaTestAndSetValid(p, "[j]", "[i * " + p.vectorWidth + " + j]");
742        mJava->endBlock();
743    }
744}
745
746void PermutationWriter::writeJavaAppendOutputToMessage(const ParameterDefinition& p,
747                                                       const string& argsIndex,
748                                                       const string& actualIndex,
749                                                       bool verifierValidates) const {
750    if (verifierValidates) {
751        mJava->indent() << "message.append(\"Output " << p.variableName << ": \");\n";
752        mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
753                        << ");\n";
754        writeJavaAppendNewLineToMessage();
755        if (p.isFloat16Parameter()) {
756            writeJavaAppendNewLineToMessage();
757            mJava->indent() << "message.append(\"Output " << p.variableName
758                            << " (in double): \");\n";
759            mJava->indent() << "appendVariableToMessage(message, args." << p.doubleVariableName
760                            << ");\n";
761            writeJavaAppendNewLineToMessage();
762        }
763    } else {
764        mJava->indent() << "message.append(\"Expected output " << p.variableName << ": \");\n";
765        mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
766                        << ");\n";
767        writeJavaAppendNewLineToMessage();
768
769        mJava->indent() << "message.append(\"Actual   output " << p.variableName << ": \");\n";
770        mJava->indent() << "appendVariableToMessage(message, " << p.javaArrayName << actualIndex
771                        << ");\n";
772
773        if (p.isFloat16Parameter()) {
774            writeJavaAppendNewLineToMessage();
775            mJava->indent() << "message.append(\"Actual   output " << p.variableName
776                            << " (in double): \");\n";
777            mJava->indent() << "appendVariableToMessage(message, Float16Utils.convertFloat16ToDouble("
778                            << p.javaArrayName << actualIndex << "));\n";
779        }
780
781        writeJavaTestOneValue(p, argsIndex, actualIndex);
782        mJava->startBlock();
783        mJava->indent() << "message.append(\" FAIL\");\n";
784        mJava->endBlock();
785        writeJavaAppendNewLineToMessage();
786    }
787}
788
789void PermutationWriter::writeJavaAppendInputToMessage(const ParameterDefinition& p,
790                                                      const string& actual) const {
791    mJava->indent() << "message.append(\"Input " << p.variableName << ": \");\n";
792    mJava->indent() << "appendVariableToMessage(message, " << actual << ");\n";
793    writeJavaAppendNewLineToMessage();
794}
795
796void PermutationWriter::writeJavaAppendNewLineToMessage() const {
797    mJava->indent() << "message.append(\"\\n\");\n";
798}
799
800void PermutationWriter::writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const {
801    if (p.mVectorSize == "1") {
802        writeJavaAppendInputToMessage(p, p.javaArrayName + "[i]");
803    } else {
804        mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
805        mJava->startBlock();
806        writeJavaAppendInputToMessage(p, p.javaArrayName + "[i * " + p.vectorWidth + " + j]");
807        mJava->endBlock();
808    }
809}
810
811void PermutationWriter::writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const {
812    if (p.mVectorSize == "1") {
813        writeJavaAppendOutputToMessage(p, "", "[i]", false);
814    } else {
815        mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
816        mJava->startBlock();
817        writeJavaAppendOutputToMessage(p, "[j]", "[i * " + p.vectorWidth + " + j]", false);
818        mJava->endBlock();
819    }
820}
821
822void PermutationWriter::writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const {
823    string script = "script";
824    if (relaxed) {
825        script += "Relaxed";
826    }
827
828    mJava->indent() << "try";
829    mJava->startBlock();
830
831    for (auto p : mAllInputsAndOutputs) {
832        if (p->isOutParameter) {
833            writeJavaOutputAllocationDefinition(*p);
834        }
835    }
836
837    for (auto p : mPermutation.getParams()) {
838        if (p != mFirstInputParam) {
839            mJava->indent() << script << ".set_" << p->rsAllocName << "(" << p->javaAllocName
840                            << ");\n";
841        }
842    }
843
844    mJava->indent() << script << ".forEach_" << mRsKernelName << "(";
845    bool needComma = false;
846    if (mFirstInputParam) {
847        *mJava << mFirstInputParam->javaAllocName;
848        needComma = true;
849    }
850    if (mReturnParam) {
851        if (needComma) {
852            *mJava << ", ";
853        }
854        *mJava << mReturnParam->variableName << ");\n";
855    }
856
857    if (generateCallToVerifier) {
858        mJava->indent() << mJavaVerifyMethodName << "(";
859        for (auto p : mAllInputsAndOutputs) {
860            *mJava << p->variableName << ", ";
861        }
862
863        if (relaxed) {
864            *mJava << "true";
865        } else {
866            *mJava << "false";
867        }
868        *mJava << ");\n";
869    }
870
871    // Generate code to destroy output Allocations.
872    for (auto p : mAllInputsAndOutputs) {
873        if (p->isOutParameter) {
874            mJava->indent() << p->javaAllocName << ".destroy();\n";
875        }
876    }
877
878    mJava->decreaseIndent();
879    mJava->indent() << "} catch (Exception e) {\n";
880    mJava->increaseIndent();
881    mJava->indent() << "throw new RSRuntimeException(\"RenderScript. Can't invoke forEach_"
882                    << mRsKernelName << ": \" + e.toString());\n";
883    mJava->endBlock();
884}
885
886/* Write the section of the .rs file for this permutation.
887 *
888 * We communicate the extra input and output parameters via global allocations.
889 * For example, if we have a function that takes three arguments, two for input
890 * and one for output:
891 *
892 * start:
893 * name: gamn
894 * ret: float3
895 * arg: float3 a
896 * arg: int b
897 * arg: float3 *c
898 * end:
899 *
900 * We'll produce:
901 *
902 * rs_allocation gAllocInB;
903 * rs_allocation gAllocOutC;
904 *
905 * float3 __attribute__((kernel)) test_gamn_float3_int_float3(float3 inA, unsigned int x) {
 *    int inB = rsGetElementAt_int(gAllocInB, x);
 *    float3 outC = 0;
 *    float3 out = gamn(inA, inB, &outC);
 *    rsSetElementAt_float3(gAllocOutC, outC, x);
 *    return out;
913 * }
914 *
915 * We avoid re-using x and y from the definition because these have reserved
916 * meanings in a .rs file.
917 */
void PermutationWriter::writeRsSection(set<string>* rsAllocationsGenerated) const {
    // Write the allocation declarations we'll need.
    for (auto p : mPermutation.getParams()) {
        // Don't need allocation for one input and one return value.
        // Every other parameter travels through a global rs_allocation. The set
        // rsAllocationsGenerated is shared across permutations so each global is declared once.
        if (p != mFirstInputParam) {
            writeRsAllocationDefinition(*p, rsAllocationsGenerated);
        }
    }
    *mRs << "\n";

    // Write the function header.
    if (mReturnParam) {
        *mRs << mReturnParam->rsType;
    } else {
        *mRs << "void";
    }
    *mRs << " __attribute__((kernel)) " << mRsKernelName;
    *mRs << "(";
    bool needComma = false;
    if (mFirstInputParam) {
        *mRs << mFirstInputParam->rsType << " " << mFirstInputParam->variableName;
        needComma = true;
    }
    // The element index x is declared only under the same condition that causes
    // rsGetElementAt_/rsSetElementAt_ calls (which use x) to be emitted below.
    if (mPermutation.getOutputCount() > 1 || mPermutation.getInputCount() > 1) {
        if (needComma) {
            *mRs << ", ";
        }
        *mRs << "unsigned int x";
    }
    *mRs << ")";
    mRs->startBlock();

    // Write the local variable declarations and initializations.
    // Out parameters start at 0; other inputs are loaded from their global allocation at x.
    for (auto p : mPermutation.getParams()) {
        if (p == mFirstInputParam) {
            continue;
        }
        mRs->indent() << p->rsType << " " << p->variableName;
        if (p->isOutParameter) {
            *mRs << " = 0;\n";
        } else {
            *mRs << " = rsGetElementAt_" << p->rsType << "(" << p->rsAllocName << ", x);\n";
        }
    }

    // Write the function call.
    // With more than one output, the result is kept in a local so the out parameters can be
    // stored before returning; otherwise the call's result is returned directly.
    // NOTE(review): with no return param, indent() is never called, so the call line of a void
    // kernel is emitted unindented (cosmetic only).
    if (mReturnParam) {
        if (mPermutation.getOutputCount() > 1) {
            mRs->indent() << mReturnParam->rsType << " " << mReturnParam->variableName << " = ";
        } else {
            mRs->indent() << "return ";
        }
    }
    *mRs << mPermutation.getName() << "(";
    needComma = false;
    for (auto p : mPermutation.getParams()) {
        if (needComma) {
            *mRs << ", ";
        }
        // Out parameters are passed by address so the function can write to the locals.
        if (p->isOutParameter) {
            *mRs << "&";
        }
        *mRs << p->variableName;
        needComma = true;
    }
    *mRs << ");\n";

    if (mPermutation.getOutputCount() > 1) {
        // Write setting the extra out parameters into the allocations.
        for (auto p : mPermutation.getParams()) {
            if (p->isOutParameter) {
                mRs->indent() << "rsSetElementAt_" << p->rsType << "(" << p->rsAllocName << ", ";
                // Check if we need to use '&' for this type of argument.
                // NOTE(review): this tests the last character of the variable *name*, not of
                // p->rsType (e.g. "float3"); confirm which was intended.
                char lastChar = p->variableName.back();
                if (lastChar >= '0' && lastChar <= '9') {
                    *mRs << "&";
                }
                *mRs << p->variableName << ", x);\n";
            }
        }
        if (mReturnParam) {
            mRs->indent() << "return " << mReturnParam->variableName << ";\n";
        }
    }
    mRs->endBlock();
}
1004
1005void PermutationWriter::writeRsAllocationDefinition(const ParameterDefinition& param,
1006                                                    set<string>* rsAllocationsGenerated) const {
1007    if (!testAndSet(param.rsAllocName, rsAllocationsGenerated)) {
1008        *mRs << "rs_allocation " << param.rsAllocName << ";\n";
1009    }
1010}
1011
1012// Open the mJavaFile and writes the header.
1013static bool startJavaFile(GeneratedFile* file, const string& directory,
1014                          const string& testName,
1015                          const string& relaxedTestName) {
1016    const string fileName = testName + ".java";
1017    if (!file->start(directory, fileName)) {
1018        return false;
1019    }
1020    file->writeNotices();
1021
1022    *file << "package android.renderscript.cts;\n\n";
1023
1024    *file << "import android.renderscript.Allocation;\n";
1025    *file << "import android.renderscript.RSRuntimeException;\n";
1026    *file << "import android.renderscript.Element;\n";
1027    *file << "import android.renderscript.cts.Target;\n\n";
1028    *file << "import java.util.Arrays;\n\n";
1029
1030    *file << "public class " << testName << " extends RSBaseCompute";
1031    file->startBlock();  // The corresponding endBlock() is in finishJavaFile()
1032    *file << "\n";
1033
1034    file->indent() << "private ScriptC_" << testName << " script;\n";
1035    file->indent() << "private ScriptC_" << relaxedTestName << " scriptRelaxed;\n\n";
1036
1037    file->indent() << "@Override\n";
1038    file->indent() << "protected void setUp() throws Exception";
1039    file->startBlock();
1040
1041    file->indent() << "super.setUp();\n";
1042    file->indent() << "script = new ScriptC_" << testName << "(mRS);\n";
1043    file->indent() << "scriptRelaxed = new ScriptC_" << relaxedTestName << "(mRS);\n";
1044
1045    file->endBlock();
1046    *file << "\n";
1047
1048    file->indent() << "@Override\n";
1049    file->indent() << "protected void tearDown() throws Exception";
1050    file->startBlock();
1051
1052    file->indent() << "script.destroy();\n";
1053    file->indent() << "scriptRelaxed.destroy();\n";
1054    file->indent() << "super.tearDown();\n";
1055
1056    file->endBlock();
1057    *file << "\n";
1058
1059    return true;
1060}
1061
1062// Write the test method that calls all the generated Check methods.
1063static void finishJavaFile(GeneratedFile* file, const Function& function,
1064                           const vector<string>& javaCheckMethods) {
1065    file->indent() << "public void test" << function.getCapitalizedName() << "()";
1066    file->startBlock();
1067    for (auto m : javaCheckMethods) {
1068        file->indent() << m << "();\n";
1069    }
1070    file->endBlock();
1071
1072    file->endBlock();
1073}
1074
1075// Open the script file and write its header.
1076static bool startRsFile(GeneratedFile* file, const string& directory,
1077                        const string& testName) {
1078    string fileName = testName + ".rs";
1079    if (!file->start(directory, fileName)) {
1080        return false;
1081    }
1082    file->writeNotices();
1083
1084    *file << "#pragma version(1)\n";
1085    *file << "#pragma rs java_package_name(android.renderscript.cts)\n\n";
1086    return true;
1087}
1088
1089// Write the entire *Relaxed.rs test file, as it only depends on the name.
1090static bool writeRelaxedRsFile(const string& directory, const string& testName,
1091                               const string& relaxedTestName) {
1092    string name = relaxedTestName + ".rs";
1093
1094    GeneratedFile file;
1095    if (!file.start(directory, name)) {
1096        return false;
1097    }
1098    file.writeNotices();
1099
1100    file << "#include \"" << testName << ".rs\"\n";
1101    file << "#pragma rs_fp_relaxed\n";
1102    file.close();
1103    return true;
1104}
1105
1106/* Write the .java and the two .rs test files.  versionOfTestFiles is used to restrict which API
1107 * to test.
1108 */
1109static bool writeTestFilesForFunction(const Function& function, const string& directory,
1110                                      unsigned int versionOfTestFiles) {
1111    // Avoid creating empty files if we're not testing this function.
1112    if (!needTestFiles(function, versionOfTestFiles)) {
1113        return true;
1114    }
1115
1116    const string testName = "Test" + function.getCapitalizedName();
1117    const string relaxedTestName = testName + "Relaxed";
1118
1119    if (!writeRelaxedRsFile(directory, testName, relaxedTestName)) {
1120        return false;
1121    }
1122
1123    GeneratedFile rsFile;    // The Renderscript test file we're generating.
1124    GeneratedFile javaFile;  // The Jave test file we're generating.
1125    if (!startRsFile(&rsFile, directory, testName)) {
1126        return false;
1127    }
1128
1129    if (!startJavaFile(&javaFile, directory, testName, relaxedTestName)) {
1130        return false;
1131    }
1132
1133    /* We keep track of the allocations generated in the .rs file and the argument classes defined
1134     * in the Java file, as we share these between the functions created for each specification.
1135     */
1136    set<string> rsAllocationsGenerated;
1137    set<string> javaGeneratedArgumentClasses;
1138    // Lines of Java code to invoke the check methods.
1139    vector<string> javaCheckMethods;
1140
1141    for (auto spec : function.getSpecifications()) {
1142        if (spec->hasTests(versionOfTestFiles)) {
1143            for (auto permutation : spec->getPermutations()) {
1144                PermutationWriter w(*permutation, &rsFile, &javaFile);
1145                w.writeRsSection(&rsAllocationsGenerated);
1146                w.writeJavaSection(&javaGeneratedArgumentClasses);
1147
1148                // Store the check method to be called.
1149                javaCheckMethods.push_back(w.getJavaCheckMethodName());
1150            }
1151        }
1152    }
1153
1154    finishJavaFile(&javaFile, function, javaCheckMethods);
1155    // There's no work to wrap-up in the .rs file.
1156
1157    rsFile.close();
1158    javaFile.close();
1159    return true;
1160}
1161
1162bool generateTestFiles(const string& directory, unsigned int versionOfTestFiles) {
1163    bool success = true;
1164    for (auto f : systemSpecification.getFunctions()) {
1165        if (!writeTestFilesForFunction(*f.second, directory, versionOfTestFiles)) {
1166            success = false;
1167        }
1168    }
1169    return success;
1170}
1171