1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <iomanip>
18#include <iostream>
19#include <cmath>
20#include <sstream>
21
22#include "Generator.h"
23#include "Specification.h"
24#include "Utilities.h"
25
26using namespace std;
27
28// Converts float2 to FLOAT_32 and 2, etc.
29static void convertToRsType(const string& name, string* dataType, char* vectorSize) {
30    string s = name;
31    int last = s.size() - 1;
32    char lastChar = s[last];
33    if (lastChar >= '1' && lastChar <= '4') {
34        s.erase(last);
35        *vectorSize = lastChar;
36    } else {
37        *vectorSize = '1';
38    }
39    dataType->clear();
40    for (int i = 0; i < NUM_TYPES; i++) {
41        if (s == TYPES[i].cType) {
42            *dataType = TYPES[i].rsDataType;
43            break;
44        }
45    }
46}
47
48// Returns true if any permutation of the function have tests to b
49static bool needTestFiles(const Function& function, int versionOfTestFiles) {
50    for (auto spec : function.getSpecifications()) {
51        if (spec->hasTests(versionOfTestFiles)) {
52            return true;
53        }
54    }
55    return false;
56}
57
/* One instance of this class is created for each permutation of a function for which
 * we are generating test code.  This instance generates both the script and the Java
 * sections of the test files for this permutation.  The class is mostly used to keep track
 * of the various names shared between the script and Java files.
 * WARNING: The object keeps a reference to the FunctionPermutation and pointers to the
 * generated files, so a PermutationWriter must not outlive any of them.
 */
class PermutationWriter {
private:
    FunctionPermutation& mPermutation;

    // Names of the generated RS kernel, Java argument classes and Java methods.  All are
    // computed once by the constructor from the permutation's name and parameter types.
    string mRsKernelName;
    string mJavaArgumentsClassName;
    string mJavaArgumentsNClassName;
    string mJavaVerifierComputeMethodName;
    string mJavaVerifierVerifyMethodName;
    string mJavaCheckMethodName;
    string mJavaVerifyMethodName;

    // Pointers to the files we are generating.  Handy to avoid always passing them in the
    // calls.  NOT OWNED.
    GeneratedFile* mRs;
    GeneratedFile* mJava;

    /* Shortcuts to the return parameter and the first input parameter of the function
     * specification.
     */
    const ParameterDefinition* mReturnParam;      // Can be nullptr.  NOT OWNED.
    const ParameterDefinition* mFirstInputParam;  // Can be nullptr.  NOT OWNED.

    /* All the parameters followed by the return param, if present.  Collecting them together
     * simplifies code generation.  NOT OWNED.
     */
    vector<const ParameterDefinition*> mAllInputsAndOutputs;

    /* We use a class to pass the arguments between the generated code and the CoreVerifier.  This
     * method generates this class.  The set keeps track of whether we've already generated this
     * class for this test file, as more than one permutation may use the same argument class.
     */
    void writeJavaArgumentClass(bool scalar, set<string>* javaGeneratedArgumentClasses) const;

    // Generate the Check* method that invokes the script and calls the verifier.
    void writeJavaCheckMethod(bool generateCallToVerifier) const;

    // Generate code to define and randomly initialize the input allocation.
    void writeJavaInputAllocationDefinition(const ParameterDefinition& param) const;

    /* Generate code that instantiates an allocation of floats or integers and fills it with
     * random data.  This random data must be compatible with the specified type.  This is
     * used for the convert_* tests, as converting values that don't fit yields undefined results.
     */
    void writeJavaRandomCompatibleFloatAllocation(const string& dataType, const string& seed,
                                                  char vectorSize,
                                                  const NumericalType& compatibleType,
                                                  const NumericalType& generatedType) const;
    void writeJavaRandomCompatibleIntegerAllocation(const string& dataType, const string& seed,
                                                    char vectorSize,
                                                    const NumericalType& compatibleType,
                                                    const NumericalType& generatedType) const;

    // Generate code that defines an output allocation.
    void writeJavaOutputAllocationDefinition(const ParameterDefinition& param) const;

    /* Generate the code that verifies the results for RenderScript functions where each entry
     * of a vector is evaluated independently.  If verifierValidates is true, CoreMathVerifier
     * does the actual validation instead of the more common case of returning the range of
     * acceptable values.
     */
    void writeJavaVerifyScalarMethod(bool verifierValidates) const;

    /* Generate the code that verifies the results for a RenderScript function where a vector
     * is a point in n-dimensional space.
     */
    void writeJavaVerifyVectorMethod() const;

    // Generate the method header of the verify function.
    void writeJavaVerifyMethodHeader() const;

    // Generate code that copies the content of an allocation to an array.
    void writeJavaArrayInitialization(const ParameterDefinition& p) const;

    // Generate code that tests one value returned from the script.  The first variant also
    // clears the generated 'valid' flag on failure; the second only emits the if-condition.
    void writeJavaTestAndSetValid(const ParameterDefinition& p, const string& argsIndex,
                                  const string& actualIndex) const;
    void writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
                               const string& actualIndex) const;
    // For test:vector cases, generate code that compares returned vector vs. expected value.
    void writeJavaVectorComparison(const ParameterDefinition& p) const;

    // Multiple functions that generate code to build the error message if an error is found.
    void writeJavaAppendOutputToMessage(const ParameterDefinition& p, const string& argsIndex,
                                        const string& actualIndex, bool verifierValidates) const;
    void writeJavaAppendInputToMessage(const ParameterDefinition& p, const string& actual) const;
    void writeJavaAppendNewLineToMessage() const;
    void writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const;
    void writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const;

    // Generate the set of instructions to call the script.
    void writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const;

    // Write an allocation definition if not already emitted in the .rs file.
    void writeRsAllocationDefinition(const ParameterDefinition& param,
                                     set<string>* rsAllocationsGenerated) const;

public:
    /* NOTE: We keep a reference to the permutation and pointers to the files.  This object
     * should not outlive the arguments.
     */
    PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
                      GeneratedFile* javaFile);
    string getJavaCheckMethodName() const { return mJavaCheckMethodName; }

    // Write the script test function for this permutation.
    void writeRsSection(set<string>* rsAllocationsGenerated) const;
    // Write the section of the Java code that calls the script and validates the results.
    void writeJavaSection(set<string>* javaGeneratedArgumentClasses) const;
};
173
174PermutationWriter::PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
175                                     GeneratedFile* javaFile)
176    : mPermutation(permutation),
177      mRs(rsFile),
178      mJava(javaFile),
179      mReturnParam(nullptr),
180      mFirstInputParam(nullptr) {
181    mRsKernelName = "test" + capitalize(permutation.getName());
182
183    mJavaArgumentsClassName = "Arguments";
184    mJavaArgumentsNClassName = "Arguments";
185    const string trunk = capitalize(permutation.getNameTrunk());
186    mJavaCheckMethodName = "check" + trunk;
187    mJavaVerifyMethodName = "verifyResults" + trunk;
188
189    for (auto p : permutation.getParams()) {
190        mAllInputsAndOutputs.push_back(p);
191        if (mFirstInputParam == nullptr && !p->isOutParameter) {
192            mFirstInputParam = p;
193        }
194    }
195    mReturnParam = permutation.getReturn();
196    if (mReturnParam) {
197        mAllInputsAndOutputs.push_back(mReturnParam);
198    }
199
200    for (auto p : mAllInputsAndOutputs) {
201        const string capitalizedRsType = capitalize(p->rsType);
202        const string capitalizedBaseType = capitalize(p->rsBaseType);
203        mRsKernelName += capitalizedRsType;
204        mJavaArgumentsClassName += capitalizedBaseType;
205        mJavaArgumentsNClassName += capitalizedBaseType;
206        if (p->mVectorSize != "1") {
207            mJavaArgumentsNClassName += "N";
208        }
209        mJavaCheckMethodName += capitalizedRsType;
210        mJavaVerifyMethodName += capitalizedRsType;
211    }
212    mJavaVerifierComputeMethodName = "compute" + trunk;
213    mJavaVerifierVerifyMethodName = "verify" + trunk;
214}
215
216void PermutationWriter::writeJavaSection(set<string>* javaGeneratedArgumentClasses) const {
217    // By default, we test the results using item by item comparison.
218    const string test = mPermutation.getTest();
219    if (test == "scalar" || test == "limited") {
220        writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
221        writeJavaCheckMethod(true);
222        writeJavaVerifyScalarMethod(false);
223    } else if (test == "custom") {
224        writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
225        writeJavaCheckMethod(true);
226        writeJavaVerifyScalarMethod(true);
227    } else if (test == "vector") {
228        writeJavaArgumentClass(false, javaGeneratedArgumentClasses);
229        writeJavaCheckMethod(true);
230        writeJavaVerifyVectorMethod();
231    } else if (test == "noverify") {
232        writeJavaCheckMethod(false);
233    }
234}
235
236void PermutationWriter::writeJavaArgumentClass(bool scalar,
237                                               set<string>* javaGeneratedArgumentClasses) const {
238    string name;
239    if (scalar) {
240        name = mJavaArgumentsClassName;
241    } else {
242        name = mJavaArgumentsNClassName;
243    }
244
245    // Make sure we have not generated the argument class already.
246    if (!testAndSet(name, javaGeneratedArgumentClasses)) {
247        mJava->indent() << "public class " << name;
248        mJava->startBlock();
249
250        for (auto p : mAllInputsAndOutputs) {
251            mJava->indent() << "public ";
252            if (p->isOutParameter && p->isFloatType && mPermutation.getTest() != "custom") {
253                *mJava << "Target.Floaty";
254            } else {
255                *mJava << p->javaBaseType;
256            }
257            if (!scalar && p->mVectorSize != "1") {
258                *mJava << "[]";
259            }
260            *mJava << " " << p->variableName << ";\n";
261        }
262        mJava->endBlock();
263        *mJava << "\n";
264    }
265}
266
267void PermutationWriter::writeJavaCheckMethod(bool generateCallToVerifier) const {
268    mJava->indent() << "private void " << mJavaCheckMethodName << "()";
269    mJava->startBlock();
270
271    // Generate the input allocations and initialization.
272    for (auto p : mAllInputsAndOutputs) {
273        if (!p->isOutParameter) {
274            writeJavaInputAllocationDefinition(*p);
275        }
276    }
277    // Generate code to enforce ordering between two allocations if needed.
278    for (auto p : mAllInputsAndOutputs) {
279        if (!p->isOutParameter && !p->smallerParameter.empty()) {
280            string smallerAlloc = "in" + capitalize(p->smallerParameter);
281            mJava->indent() << "enforceOrdering(" << smallerAlloc << ", " << p->javaAllocName
282                            << ");\n";
283        }
284    }
285
286    // Generate code to check the full and relaxed scripts.
287    writeJavaCallToRs(false, generateCallToVerifier);
288    writeJavaCallToRs(true, generateCallToVerifier);
289
290    mJava->endBlock();
291    *mJava << "\n";
292}
293
294void PermutationWriter::writeJavaInputAllocationDefinition(const ParameterDefinition& param) const {
295    string dataType;
296    char vectorSize;
297    convertToRsType(param.rsType, &dataType, &vectorSize);
298
299    const string seed = hashString(mJavaCheckMethodName + param.javaAllocName);
300    mJava->indent() << "Allocation " << param.javaAllocName << " = ";
301    if (param.compatibleTypeIndex >= 0) {
302        if (TYPES[param.typeIndex].kind == FLOATING_POINT) {
303            writeJavaRandomCompatibleFloatAllocation(dataType, seed, vectorSize,
304                                                     TYPES[param.compatibleTypeIndex],
305                                                     TYPES[param.typeIndex]);
306        } else {
307            writeJavaRandomCompatibleIntegerAllocation(dataType, seed, vectorSize,
308                                                       TYPES[param.compatibleTypeIndex],
309                                                       TYPES[param.typeIndex]);
310        }
311    } else if (!param.minValue.empty()) {
312        *mJava << "createRandomFloatAllocation(mRS, Element.DataType." << dataType << ", "
313               << vectorSize << ", " << seed << ", " << param.minValue << ", " << param.maxValue
314               << ")";
315    } else {
316        /* TODO Instead of passing always false, check whether we are doing a limited test.
317         * Use instead: (mPermutation.getTest() == "limited" ? "false" : "true")
318         */
319        *mJava << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize
320               << ", " << seed << ", false)";
321    }
322    *mJava << ";\n";
323}
324
325void PermutationWriter::writeJavaRandomCompatibleFloatAllocation(
326            const string& dataType, const string& seed, char vectorSize,
327            const NumericalType& compatibleType, const NumericalType& generatedType) const {
328    *mJava << "createRandomFloatAllocation"
329           << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
330    double minValue = 0.0;
331    double maxValue = 0.0;
332    switch (compatibleType.kind) {
333        case FLOATING_POINT: {
334            // We're generating floating point values.  We just worry about the exponent.
335            // Subtract 1 for the exponent sign.
336            int bits = min(compatibleType.exponentBits, generatedType.exponentBits) - 1;
337            maxValue = ldexp(0.95, (1 << bits) - 1);
338            minValue = -maxValue;
339            break;
340        }
341        case UNSIGNED_INTEGER:
342            maxValue = maxDoubleForInteger(compatibleType.significantBits,
343                                           generatedType.significantBits);
344            minValue = 0.0;
345            break;
346        case SIGNED_INTEGER:
347            maxValue = maxDoubleForInteger(compatibleType.significantBits,
348                                           generatedType.significantBits);
349            minValue = -maxValue - 1.0;
350            break;
351    }
352    *mJava << scientific << std::setprecision(19);
353    *mJava << minValue << ", " << maxValue << ")";
354    mJava->unsetf(ios_base::floatfield);
355}
356
357void PermutationWriter::writeJavaRandomCompatibleIntegerAllocation(
358            const string& dataType, const string& seed, char vectorSize,
359            const NumericalType& compatibleType, const NumericalType& generatedType) const {
360    *mJava << "createRandomIntegerAllocation"
361           << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
362
363    if (compatibleType.kind == FLOATING_POINT) {
364        // Currently, all floating points can take any number we generate.
365        bool isSigned = generatedType.kind == SIGNED_INTEGER;
366        *mJava << (isSigned ? "true" : "false") << ", " << generatedType.significantBits;
367    } else {
368        bool isSigned =
369                    compatibleType.kind == SIGNED_INTEGER && generatedType.kind == SIGNED_INTEGER;
370        *mJava << (isSigned ? "true" : "false") << ", "
371               << min(compatibleType.significantBits, generatedType.significantBits);
372    }
373    *mJava << ")";
374}
375
376void PermutationWriter::writeJavaOutputAllocationDefinition(
377            const ParameterDefinition& param) const {
378    string dataType;
379    char vectorSize;
380    convertToRsType(param.rsType, &dataType, &vectorSize);
381    mJava->indent() << "Allocation " << param.javaAllocName << " = Allocation.createSized(mRS, "
382                    << "getElement(mRS, Element.DataType." << dataType << ", " << vectorSize
383                    << "), INPUTSIZE);\n";
384}
385
/* Generates the Java verify method for tests where each element of a vector is checked
 * independently.
 *
 * The generated code copies every allocation into a Java array, then loops over each input
 * index i and each vector lane j.  For every (i, j) it fills an Arguments object with the
 * inputs and then either:
 *  - verifierValidates == false: calls CoreMathVerifier's compute method to obtain the
 *    expected outputs and compares them with the values returned by the script, or
 *  - verifierValidates == true: also copies the actual outputs into the Arguments object and
 *    calls CoreMathVerifier's verify method, treating a non-null returned message as failure.
 * Only the first failure gets a full dump of inputs and outputs; every failure appends its
 * [i, j] position.  The generated method finally asserts that no error was found.
 */
void PermutationWriter::writeJavaVerifyScalarMethod(bool verifierValidates) const {
    writeJavaVerifyMethodHeader();
    mJava->startBlock();

    // Declare and fill one Java array per input/output.  Also determine the vector width that
    // drives the inner j loop.  All vector parameters are expected to share one width; a
    // mismatch is reported on stderr and generation continues with the first width seen.
    string vectorSize = "1";
    for (auto p : mAllInputsAndOutputs) {
        writeJavaArrayInitialization(*p);
        if (p->mVectorSize != "1" && p->mVectorSize != vectorSize) {
            if (vectorSize == "1") {
                vectorSize = p->mVectorSize;
            } else {
                cerr << "Error.  Had vector " << vectorSize << " and " << p->mVectorSize << "\n";
            }
        }
    }

    mJava->indent() << "StringBuilder message = new StringBuilder();\n";
    mJava->indent() << "boolean errorFound = false;\n";
    mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
    mJava->startBlock();

    mJava->indent() << "for (int j = 0; j < " << vectorSize << " ; j++)";
    mJava->startBlock();

    // Scalar inputs (vectorWidth "1") are indexed by i only; vector inputs by lane as well.
    mJava->indent() << "// Extract the inputs.\n";
    mJava->indent() << mJavaArgumentsClassName << " args = new " << mJavaArgumentsClassName
                    << "();\n";
    for (auto p : mAllInputsAndOutputs) {
        if (!p->isOutParameter) {
            mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i";
            if (p->vectorWidth != "1") {
                *mJava << " * " << p->vectorWidth << " + j";
            }
            *mJava << "];\n";
        }
    }
    // A Target is only created and passed to CoreMathVerifier when the permutation has
    // floating-point answers.
    const bool hasFloat = mPermutation.hasFloatAnswers();
    if (verifierValidates) {
        mJava->indent() << "// Extract the outputs.\n";
        for (auto p : mAllInputsAndOutputs) {
            if (p->isOutParameter) {
                mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName
                                << "[i * " << p->vectorWidth << " + j];\n";
            }
        }
        mJava->indent() << "// Ask the CoreMathVerifier to validate.\n";
        if (hasFloat) {
            mJava->indent() << "Target target = new Target(relaxed);\n";
        }
        mJava->indent() << "String errorMessage = CoreMathVerifier."
                        << mJavaVerifierVerifyMethodName << "(args";
        if (hasFloat) {
            *mJava << ", target";
        }
        *mJava << ");\n";
        mJava->indent() << "boolean valid = errorMessage == null;\n";
    } else {
        mJava->indent() << "// Figure out what the outputs should have been.\n";
        if (hasFloat) {
            mJava->indent() << "Target target = new Target(relaxed);\n";
        }
        mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args";
        if (hasFloat) {
            *mJava << ", target";
        }
        *mJava << ");\n";
        mJava->indent() << "// Validate the outputs.\n";
        mJava->indent() << "boolean valid = true;\n";
        for (auto p : mAllInputsAndOutputs) {
            if (p->isOutParameter) {
                writeJavaTestAndSetValid(*p, "", "[i * " + p->vectorWidth + " + j]");
            }
        }
    }

    // On failure: dump the details for the first failing entry only...
    mJava->indent() << "if (!valid)";
    mJava->startBlock();
    mJava->indent() << "if (!errorFound)";
    mJava->startBlock();
    mJava->indent() << "errorFound = true;\n";

    for (auto p : mAllInputsAndOutputs) {
        if (p->isOutParameter) {
            writeJavaAppendOutputToMessage(*p, "", "[i * " + p->vectorWidth + " + j]",
                                           verifierValidates);
        } else {
            writeJavaAppendInputToMessage(*p, "args." + p->variableName);
        }
    }
    if (verifierValidates) {
        mJava->indent() << "message.append(errorMessage);\n";
    }
    mJava->indent() << "message.append(\"Errors at\");\n";
    mJava->endBlock();

    // ...and append the [i, j] position of every failing entry.
    mJava->indent() << "message.append(\" [\");\n";
    mJava->indent() << "message.append(Integer.toString(i));\n";
    mJava->indent() << "message.append(\", \");\n";
    mJava->indent() << "message.append(Integer.toString(j));\n";
    mJava->indent() << "message.append(\"]\");\n";

    // Close, in order: "if (!valid)", the j loop and the i loop.
    mJava->endBlock();
    mJava->endBlock();
    mJava->endBlock();

    mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
    mJava->indentPlus()
                << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";

    // Close the generated method.
    mJava->endBlock();
    *mJava << "\n";
}
498
/* Generates the Java verify method for tests that treat each vector as a whole.
 *
 * For each input index i, the generated code builds an instance of the array-field argument
 * class (mJavaArgumentsNClassName), allocates arrays for its vector fields, copies the inputs
 * in, and calls CoreMathVerifier's compute method to obtain the expected outputs.  Every lane
 * of every output is then compared with the values returned by the script.  Only the first
 * failure gets a full dump; every failure appends its [i] position.  The generated method
 * finally asserts that no error was found.
 */
void PermutationWriter::writeJavaVerifyVectorMethod() const {
    writeJavaVerifyMethodHeader();
    mJava->startBlock();

    // Declare and fill one Java array per input/output.
    for (auto p : mAllInputsAndOutputs) {
        writeJavaArrayInitialization(*p);
    }
    mJava->indent() << "StringBuilder message = new StringBuilder();\n";
    mJava->indent() << "boolean errorFound = false;\n";
    mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
    mJava->startBlock();

    mJava->indent() << mJavaArgumentsNClassName << " args = new " << mJavaArgumentsNClassName
                    << "();\n";

    // Float outputs hold Target.Floaty values; other fields use their Java base type.
    mJava->indent() << "// Create the appropriate sized arrays in args\n";
    for (auto p : mAllInputsAndOutputs) {
        if (p->mVectorSize != "1") {
            string type = p->javaBaseType;
            if (p->isOutParameter && p->isFloatType) {
                type = "Target.Floaty";
            }
            mJava->indent() << "args." << p->variableName << " = new " << type << "["
                            << p->mVectorSize << "];\n";
        }
    }

    mJava->indent() << "// Fill args with the input values\n";
    for (auto p : mAllInputsAndOutputs) {
        if (!p->isOutParameter) {
            if (p->mVectorSize == "1") {
                mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i]"
                                << ";\n";
            } else {
                mJava->indent() << "for (int j = 0; j < " << p->mVectorSize << " ; j++)";
                mJava->startBlock();
                mJava->indent() << "args." << p->variableName << "[j] = "
                                << p->javaArrayName << "[i * " << p->vectorWidth << " + j]"
                                << ";\n";
                mJava->endBlock();
            }
        }
    }
    mJava->indent() << "Target target = new Target(relaxed);\n";
    mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName
                    << "(args, target);\n\n";

    mJava->indent() << "// Compare the expected outputs to the actual values returned by RS.\n";
    mJava->indent() << "boolean valid = true;\n";
    for (auto p : mAllInputsAndOutputs) {
        if (p->isOutParameter) {
            writeJavaVectorComparison(*p);
        }
    }

    // On failure: dump the details for the first failing entry only...
    mJava->indent() << "if (!valid)";
    mJava->startBlock();
    mJava->indent() << "if (!errorFound)";
    mJava->startBlock();
    mJava->indent() << "errorFound = true;\n";

    for (auto p : mAllInputsAndOutputs) {
        if (p->isOutParameter) {
            writeJavaAppendVectorOutputToMessage(*p);
        } else {
            writeJavaAppendVectorInputToMessage(*p);
        }
    }
    mJava->indent() << "message.append(\"Errors at\");\n";
    mJava->endBlock();

    // ...and append the [i] position of every failing entry.
    mJava->indent() << "message.append(\" [\");\n";
    mJava->indent() << "message.append(Integer.toString(i));\n";
    mJava->indent() << "message.append(\"]\");\n";

    // Close, in order: "if (!valid)" and the i loop.
    mJava->endBlock();
    mJava->endBlock();

    mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
    mJava->indentPlus()
                << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";

    // Close the generated method.
    mJava->endBlock();
    *mJava << "\n";
}
584
585void PermutationWriter::writeJavaVerifyMethodHeader() const {
586    mJava->indent() << "private void " << mJavaVerifyMethodName << "(";
587    for (auto p : mAllInputsAndOutputs) {
588        *mJava << "Allocation " << p->javaAllocName << ", ";
589    }
590    *mJava << "boolean relaxed)";
591}
592
593void PermutationWriter::writeJavaArrayInitialization(const ParameterDefinition& p) const {
594    mJava->indent() << p.javaBaseType << "[] " << p.javaArrayName << " = new " << p.javaBaseType
595                    << "[INPUTSIZE * " << p.vectorWidth << "];\n";
596
597    /* For basic types, populate the array with values, to help understand failures.  We have had
598     * bugs where the output buffer was all 0.  We were not sure if there was a failed copy or
599     * the GPU driver was copying zeroes.
600     */
601    if (p.typeIndex >= 0) {
602        mJava->indent() << "Arrays.fill(" << p.javaArrayName << ", (" << TYPES[p.typeIndex].javaType
603                        << ") 42);\n";
604    }
605
606    mJava->indent() << p.javaAllocName << ".copyTo(" << p.javaArrayName << ");\n";
607}
608
609void PermutationWriter::writeJavaTestAndSetValid(const ParameterDefinition& p,
610                                                 const string& argsIndex,
611                                                 const string& actualIndex) const {
612    writeJavaTestOneValue(p, argsIndex, actualIndex);
613    mJava->startBlock();
614    mJava->indent() << "valid = false;\n";
615    mJava->endBlock();
616}
617
618void PermutationWriter::writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
619                                              const string& actualIndex) const {
620    mJava->indent() << "if (";
621    if (p.isFloatType) {
622        *mJava << "!args." << p.variableName << argsIndex << ".couldBe(" << p.javaArrayName
623               << actualIndex;
624        const string s = mPermutation.getPrecisionLimit();
625        if (!s.empty()) {
626            *mJava << ", " << s;
627        }
628        *mJava << ")";
629    } else {
630        *mJava << "args." << p.variableName << argsIndex << " != " << p.javaArrayName
631               << actualIndex;
632    }
633
634    if (p.undefinedIfOutIsNan && mReturnParam) {
635        *mJava << " && !args." << mReturnParam->variableName << argsIndex << ".isNaN()";
636    }
637    *mJava << ")";
638}
639
640void PermutationWriter::writeJavaVectorComparison(const ParameterDefinition& p) const {
641    if (p.mVectorSize == "1") {
642        writeJavaTestAndSetValid(p, "", "[i]");
643    } else {
644        mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
645        mJava->startBlock();
646        writeJavaTestAndSetValid(p, "[j]", "[i * " + p.vectorWidth + " + j]");
647        mJava->endBlock();
648    }
649}
650
651void PermutationWriter::writeJavaAppendOutputToMessage(const ParameterDefinition& p,
652                                                       const string& argsIndex,
653                                                       const string& actualIndex,
654                                                       bool verifierValidates) const {
655    if (verifierValidates) {
656        mJava->indent() << "message.append(\"Output " << p.variableName << ": \");\n";
657        mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
658                        << ");\n";
659        writeJavaAppendNewLineToMessage();
660    } else {
661        mJava->indent() << "message.append(\"Expected output " << p.variableName << ": \");\n";
662        mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
663                        << ");\n";
664        writeJavaAppendNewLineToMessage();
665
666        mJava->indent() << "message.append(\"Actual   output " << p.variableName << ": \");\n";
667        mJava->indent() << "appendVariableToMessage(message, " << p.javaArrayName << actualIndex
668                        << ");\n";
669
670        writeJavaTestOneValue(p, argsIndex, actualIndex);
671        mJava->startBlock();
672        mJava->indent() << "message.append(\" FAIL\");\n";
673        mJava->endBlock();
674        writeJavaAppendNewLineToMessage();
675    }
676}
677
678void PermutationWriter::writeJavaAppendInputToMessage(const ParameterDefinition& p,
679                                                      const string& actual) const {
680    mJava->indent() << "message.append(\"Input " << p.variableName << ": \");\n";
681    mJava->indent() << "appendVariableToMessage(message, " << actual << ");\n";
682    writeJavaAppendNewLineToMessage();
683}
684
685void PermutationWriter::writeJavaAppendNewLineToMessage() const {
686    mJava->indent() << "message.append(\"\\n\");\n";
687}
688
689void PermutationWriter::writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const {
690    if (p.mVectorSize == "1") {
691        writeJavaAppendInputToMessage(p, p.javaArrayName + "[i]");
692    } else {
693        mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
694        mJava->startBlock();
695        writeJavaAppendInputToMessage(p, p.javaArrayName + "[i * " + p.vectorWidth + " + j]");
696        mJava->endBlock();
697    }
698}
699
700void PermutationWriter::writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const {
701    if (p.mVectorSize == "1") {
702        writeJavaAppendOutputToMessage(p, "", "[i]", false);
703    } else {
704        mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
705        mJava->startBlock();
706        writeJavaAppendOutputToMessage(p, "[j]", "[i * " + p.vectorWidth + " + j]", false);
707        mJava->endBlock();
708    }
709}
710
711void PermutationWriter::writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const {
712    string script = "script";
713    if (relaxed) {
714        script += "Relaxed";
715    }
716
717    mJava->indent() << "try";
718    mJava->startBlock();
719
720    for (auto p : mAllInputsAndOutputs) {
721        if (p->isOutParameter) {
722            writeJavaOutputAllocationDefinition(*p);
723        }
724    }
725
726    for (auto p : mPermutation.getParams()) {
727        if (p != mFirstInputParam) {
728            mJava->indent() << script << ".set_" << p->rsAllocName << "(" << p->javaAllocName
729                            << ");\n";
730        }
731    }
732
733    mJava->indent() << script << ".forEach_" << mRsKernelName << "(";
734    bool needComma = false;
735    if (mFirstInputParam) {
736        *mJava << mFirstInputParam->javaAllocName;
737        needComma = true;
738    }
739    if (mReturnParam) {
740        if (needComma) {
741            *mJava << ", ";
742        }
743        *mJava << mReturnParam->variableName << ");\n";
744    }
745
746    if (generateCallToVerifier) {
747        mJava->indent() << mJavaVerifyMethodName << "(";
748        for (auto p : mAllInputsAndOutputs) {
749            *mJava << p->variableName << ", ";
750        }
751
752        if (relaxed) {
753            *mJava << "true";
754        } else {
755            *mJava << "false";
756        }
757        *mJava << ");\n";
758    }
759    mJava->decreaseIndent();
760    mJava->indent() << "} catch (Exception e) {\n";
761    mJava->increaseIndent();
762    mJava->indent() << "throw new RSRuntimeException(\"RenderScript. Can't invoke forEach_"
763                    << mRsKernelName << ": \" + e.toString());\n";
764    mJava->endBlock();
765}
766
767/* Write the section of the .rs file for this permutation.
768 *
769 * We communicate the extra input and output parameters via global allocations.
770 * For example, if we have a function that takes three arguments, two for input
771 * and one for output:
772 *
773 * start:
774 * name: gamn
775 * ret: float3
776 * arg: float3 a
777 * arg: int b
778 * arg: float3 *c
779 * end:
780 *
781 * We'll produce:
782 *
783 * rs_allocation gAllocInB;
784 * rs_allocation gAllocOutC;
785 *
 * float3 __attribute__((kernel)) test_gamn_float3_int_float3(float3 inA, unsigned int x) {
 *    int inB = rsGetElementAt_int(gAllocInB, x);
 *    float3 outC = 0;
 *    float3 out = gamn(inA, inB, &outC);
 *    rsSetElementAt_float3(gAllocOutC, outC, x);
 *    return out;
 * }
795 *
796 * We avoid re-using x and y from the definition because these have reserved
797 * meanings in a .rs file.
798 */
799void PermutationWriter::writeRsSection(set<string>* rsAllocationsGenerated) const {
800    // Write the allocation declarations we'll need.
801    for (auto p : mPermutation.getParams()) {
802        // Don't need allocation for one input and one return value.
803        if (p != mFirstInputParam) {
804            writeRsAllocationDefinition(*p, rsAllocationsGenerated);
805        }
806    }
807    *mRs << "\n";
808
809    // Write the function header.
810    if (mReturnParam) {
811        *mRs << mReturnParam->rsType;
812    } else {
813        *mRs << "void";
814    }
815    *mRs << " __attribute__((kernel)) " << mRsKernelName;
816    *mRs << "(";
817    bool needComma = false;
818    if (mFirstInputParam) {
819        *mRs << mFirstInputParam->rsType << " " << mFirstInputParam->variableName;
820        needComma = true;
821    }
822    if (mPermutation.getOutputCount() > 1 || mPermutation.getInputCount() > 1) {
823        if (needComma) {
824            *mRs << ", ";
825        }
826        *mRs << "unsigned int x";
827    }
828    *mRs << ")";
829    mRs->startBlock();
830
831    // Write the local variable declarations and initializations.
832    for (auto p : mPermutation.getParams()) {
833        if (p == mFirstInputParam) {
834            continue;
835        }
836        mRs->indent() << p->rsType << " " << p->variableName;
837        if (p->isOutParameter) {
838            *mRs << " = 0;\n";
839        } else {
840            *mRs << " = rsGetElementAt_" << p->rsType << "(" << p->rsAllocName << ", x);\n";
841        }
842    }
843
844    // Write the function call.
845    if (mReturnParam) {
846        if (mPermutation.getOutputCount() > 1) {
847            mRs->indent() << mReturnParam->rsType << " " << mReturnParam->variableName << " = ";
848        } else {
849            mRs->indent() << "return ";
850        }
851    }
852    *mRs << mPermutation.getName() << "(";
853    needComma = false;
854    for (auto p : mPermutation.getParams()) {
855        if (needComma) {
856            *mRs << ", ";
857        }
858        if (p->isOutParameter) {
859            *mRs << "&";
860        }
861        *mRs << p->variableName;
862        needComma = true;
863    }
864    *mRs << ");\n";
865
866    if (mPermutation.getOutputCount() > 1) {
867        // Write setting the extra out parameters into the allocations.
868        for (auto p : mPermutation.getParams()) {
869            if (p->isOutParameter) {
870                mRs->indent() << "rsSetElementAt_" << p->rsType << "(" << p->rsAllocName << ", ";
871                // Check if we need to use '&' for this type of argument.
872                char lastChar = p->variableName.back();
873                if (lastChar >= '0' && lastChar <= '9') {
874                    *mRs << "&";
875                }
876                *mRs << p->variableName << ", x);\n";
877            }
878        }
879        if (mReturnParam) {
880            mRs->indent() << "return " << mReturnParam->variableName << ";\n";
881        }
882    }
883    mRs->endBlock();
884}
885
886void PermutationWriter::writeRsAllocationDefinition(const ParameterDefinition& param,
887                                                    set<string>* rsAllocationsGenerated) const {
888    if (!testAndSet(param.rsAllocName, rsAllocationsGenerated)) {
889        *mRs << "rs_allocation " << param.rsAllocName << ";\n";
890    }
891}
892
893// Open the mJavaFile and writes the header.
894static bool startJavaFile(GeneratedFile* file, const Function& function, const string& directory,
895                          const string& testName, const string& relaxedTestName) {
896    const string fileName = testName + ".java";
897    if (!file->start(directory, fileName)) {
898        return false;
899    }
900    file->writeNotices();
901
902    *file << "package android.renderscript.cts;\n\n";
903
904    *file << "import android.renderscript.Allocation;\n";
905    *file << "import android.renderscript.RSRuntimeException;\n";
906    *file << "import android.renderscript.Element;\n\n";
907    *file << "import java.util.Arrays;\n\n";
908
909    *file << "public class " << testName << " extends RSBaseCompute";
910    file->startBlock();  // The corresponding endBlock() is in finishJavaFile()
911    *file << "\n";
912
913    file->indent() << "private ScriptC_" << testName << " script;\n";
914    file->indent() << "private ScriptC_" << relaxedTestName << " scriptRelaxed;\n\n";
915
916    file->indent() << "@Override\n";
917    file->indent() << "protected void setUp() throws Exception";
918    file->startBlock();
919
920    file->indent() << "super.setUp();\n";
921    file->indent() << "script = new ScriptC_" << testName << "(mRS);\n";
922    file->indent() << "scriptRelaxed = new ScriptC_" << relaxedTestName << "(mRS);\n";
923
924    file->endBlock();
925    *file << "\n";
926    return true;
927}
928
929// Write the test method that calls all the generated Check methods.
930static void finishJavaFile(GeneratedFile* file, const Function& function,
931                           const vector<string>& javaCheckMethods) {
932    file->indent() << "public void test" << function.getCapitalizedName() << "()";
933    file->startBlock();
934    for (auto m : javaCheckMethods) {
935        file->indent() << m << "();\n";
936    }
937    file->endBlock();
938
939    file->endBlock();
940}
941
942// Open the script file and write its header.
943static bool startRsFile(GeneratedFile* file, const Function& function, const string& directory,
944                        const string& testName) {
945    string fileName = testName + ".rs";
946    if (!file->start(directory, fileName)) {
947        return false;
948    }
949    file->writeNotices();
950
951    *file << "#pragma version(1)\n";
952    *file << "#pragma rs java_package_name(android.renderscript.cts)\n\n";
953    return true;
954}
955
956// Write the entire *Relaxed.rs test file, as it only depends on the name.
957static bool writeRelaxedRsFile(const Function& function, const string& directory,
958                               const string& testName, const string& relaxedTestName) {
959    string name = relaxedTestName + ".rs";
960
961    GeneratedFile file;
962    if (!file.start(directory, name)) {
963        return false;
964    }
965    file.writeNotices();
966
967    file << "#include \"" << testName << ".rs\"\n";
968    file << "#pragma rs_fp_relaxed\n";
969    file.close();
970    return true;
971}
972
973/* Write the .java and the two .rs test files.  versionOfTestFiles is used to restrict which API
974 * to test.
975 */
976static bool writeTestFilesForFunction(const Function& function, const string& directory,
977                                      int versionOfTestFiles) {
978    // Avoid creating empty files if we're not testing this function.
979    if (!needTestFiles(function, versionOfTestFiles)) {
980        return true;
981    }
982
983    const string testName = "Test" + function.getCapitalizedName();
984    const string relaxedTestName = testName + "Relaxed";
985
986    if (!writeRelaxedRsFile(function, directory, testName, relaxedTestName)) {
987        return false;
988    }
989
990    GeneratedFile rsFile;    // The Renderscript test file we're generating.
991    GeneratedFile javaFile;  // The Jave test file we're generating.
992    if (!startRsFile(&rsFile, function, directory, testName)) {
993        return false;
994    }
995
996    if (!startJavaFile(&javaFile, function, directory, testName, relaxedTestName)) {
997        return false;
998    }
999
1000    /* We keep track of the allocations generated in the .rs file and the argument classes defined
1001     * in the Java file, as we share these between the functions created for each specification.
1002     */
1003    set<string> rsAllocationsGenerated;
1004    set<string> javaGeneratedArgumentClasses;
1005    // Lines of Java code to invoke the check methods.
1006    vector<string> javaCheckMethods;
1007
1008    for (auto spec : function.getSpecifications()) {
1009        if (spec->hasTests(versionOfTestFiles)) {
1010            for (auto permutation : spec->getPermutations()) {
1011                PermutationWriter w(*permutation, &rsFile, &javaFile);
1012                w.writeRsSection(&rsAllocationsGenerated);
1013                w.writeJavaSection(&javaGeneratedArgumentClasses);
1014
1015                // Store the check method to be called.
1016                javaCheckMethods.push_back(w.getJavaCheckMethodName());
1017            }
1018        }
1019    }
1020
1021    finishJavaFile(&javaFile, function, javaCheckMethods);
1022    // There's no work to wrap-up in the .rs file.
1023
1024    rsFile.close();
1025    javaFile.close();
1026    return true;
1027}
1028
1029bool generateTestFiles(const string& directory, int versionOfTestFiles) {
1030    bool success = true;
1031    for (auto f : systemSpecification.getFunctions()) {
1032        if (!writeTestFilesForFunction(*f.second, directory, versionOfTestFiles)) {
1033            success = false;
1034        }
1035    }
1036    return success;
1037}
1038