1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Top level driver for models and examples generated by test_generator.py
18
19#include "Bridge.h"
20#include "NeuralNetworksWrapper.h"
21#include "TestHarness.h"
22
23#include <gtest/gtest.h>
24#include <cassert>
25#include <cmath>
26#include <fstream>
27#include <iostream>
28#include <map>
29
30// Uncomment the following line to generate DOT graphs.
31//
32// #define GRAPH GRAPH
33
34namespace generated_tests {
35using namespace android::nn::wrapper;
36using namespace test_helper;
37
38void graphDump([[maybe_unused]] const char* name, [[maybe_unused]] const Model& model) {
39#ifdef GRAPH
40    ::android::nn::bridge_tests::graphDump(
41         name,
42         reinterpret_cast<const ::android::nn::ModelBuilder*>(model.getHandle()));
43#endif
44}
45
46template <typename T>
47static void print(std::ostream& os, const MixedTyped& test) {
48    // dump T-typed inputs
49    for_each<T>(test, [&os](int idx, const std::vector<T>& f) {
50        os << "    aliased_output" << idx << ": [";
51        for (size_t i = 0; i < f.size(); ++i) {
52            os << (i == 0 ? "" : ", ") << +f[i];
53        }
54        os << "],\n";
55    });
56}
57
58static void printAll(std::ostream& os, const MixedTyped& test) {
59    print<float>(os, test);
60    print<int32_t>(os, test);
61    print<uint8_t>(os, test);
62}
63
64// Test driver for those generated from ml/nn/runtime/test/spec
65static void execute(std::function<void(Model*)> createModel,
66             std::function<bool(int)> isIgnored,
67             std::vector<MixedTypedExampleType>& examples,
68             std::string dumpFile = "") {
69    Model model;
70    createModel(&model);
71    model.finish();
72    graphDump("", model);
73    bool dumpToFile = !dumpFile.empty();
74
75    std::ofstream s;
76    if (dumpToFile) {
77        s.open(dumpFile, std::ofstream::trunc);
78        ASSERT_TRUE(s.is_open());
79    }
80
81    int exampleNo = 0;
82    Compilation compilation(&model);
83    compilation.finish();
84
85    // If in relaxed mode, set the error range to be 5ULP of FP16.
86    float fpRange = !model.isRelaxed() ? 1e-5f : 5.0f * 0.0009765625f;
87    for (auto& example : examples) {
88        SCOPED_TRACE(exampleNo);
89        // TODO: We leave it as a copy here.
90        // Should verify if the input gets modified by the test later.
91        MixedTyped inputs = example.first;
92        const MixedTyped& golden = example.second;
93
94        Execution execution(&compilation);
95
96        // Set all inputs
97        for_all(inputs, [&execution](int idx, const void* p, size_t s) {
98            const void* buffer = s == 0 ? nullptr : p;
99            ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, buffer, s));
100        });
101
102        MixedTyped test;
103        // Go through all typed outputs
104        resize_accordingly(golden, test);
105        for_all(test, [&execution](int idx, void* p, size_t s) {
106            void* buffer = s == 0 ? nullptr : p;
107            ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, buffer, s));
108        });
109
110        Result r = execution.compute();
111        ASSERT_EQ(Result::NO_ERROR, r);
112
113        // Dump all outputs for the slicing tool
114        if (dumpToFile) {
115            s << "output" << exampleNo << " = {\n";
116            printAll(s, test);
117            // all outputs are done
118            s << "}\n";
119        }
120
121        // Filter out don't cares
122        MixedTyped filteredGolden = filter(golden, isIgnored);
123        MixedTyped filteredTest = filter(test, isIgnored);
124        // We want "close-enough" results for float
125
126        compare(filteredGolden, filteredTest, fpRange);
127        exampleNo++;
128    }
129}
130
131};  // namespace generated_tests
132
133using namespace android::nn::wrapper;
134
135// Mixed-typed examples
136typedef test_helper::MixedTypedExampleType MixedTypedExample;
137
138class GeneratedTests : public ::testing::Test {
139protected:
140    virtual void SetUp() {}
141};
142
143// Testcases generated from runtime/test/specs/*.mod.py
144using namespace test_helper;
145using namespace generated_tests;
146#include "generated/all_generated_tests.cpp"
147// End of testcases generated from runtime/test/specs/*.mod.py
148