TestGenerated.cpp revision 663155d58ca1e1eb42e01495e236258e0c00d40f
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Top level driver for models and examples generated by test_generator.py
18
19#include "Bridge.h"
20#include "NeuralNetworksWrapper.h"
21#include "TestHarness.h"
22
23#include <gtest/gtest.h>
24#include <cassert>
25#include <cmath>
26#include <fstream>
27#include <iostream>
28#include <map>
29
30// Uncomment the following line to generate DOT graphs.
31//
32// #define GRAPH GRAPH
33
34namespace generated_tests {
35using namespace android::nn::wrapper;
36
37void graphDump([[maybe_unused]] const char* name, [[maybe_unused]] const Model& model) {
38#ifdef GRAPH
39    ::android::nn::bridge_tests::graphDump(
40         name,
41         reinterpret_cast<const ::android::nn::ModelBuilder*>(model.getHandle()));
42#endif
43}
44
45template <typename T>
46static void print(std::ostream& os, const MixedTyped& test) {
47    // dump T-typed inputs
48    for_each<T>(test, [&os](int idx, const std::vector<T>& f) {
49        os << "    aliased_output" << idx << ": [";
50        for (size_t i = 0; i < f.size(); ++i) {
51            os << (i == 0 ? "" : ", ") << +f[i];
52        }
53        os << "],\n";
54    });
55}
56
57static void printAll(std::ostream& os, const MixedTyped& test) {
58    print<float>(os, test);
59    print<int32_t>(os, test);
60    print<uint8_t>(os, test);
61}
62
63// Test driver for those generated from ml/nn/runtime/test/spec
64static void execute(std::function<void(Model*)> createModel,
65             std::function<bool(int)> isIgnored,
66             std::vector<MixedTypedExampleType>& examples,
67             std::string dumpFile = "") {
68    Model model;
69    createModel(&model);
70    model.finish();
71    graphDump("", model);
72    bool dumpToFile = !dumpFile.empty();
73
74    std::ofstream s;
75    if (dumpToFile) {
76        s.open(dumpFile, std::ofstream::trunc);
77        ASSERT_TRUE(s.is_open());
78    }
79
80    int exampleNo = 0;
81    Compilation compilation(&model);
82    compilation.finish();
83    for (auto& example : examples) {
84        SCOPED_TRACE(exampleNo);
85        // TODO: We leave it as a copy here.
86        // Should verify if the input gets modified by the test later.
87        MixedTyped inputs = example.first;
88        const MixedTyped& golden = example.second;
89
90        Execution execution(&compilation);
91
92        // Set all inputs
93        for_all(inputs, [&execution](int idx, const void* p, size_t s) {
94            ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, p, s));
95        });
96
97        MixedTyped test;
98        // Go through all typed outputs
99        resize_accordingly(golden, test);
100        for_all(test, [&execution](int idx, void* p, size_t s) {
101            ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, p, s));
102        });
103
104        Result r = execution.compute();
105        ASSERT_EQ(Result::NO_ERROR, r);
106
107        // Dump all outputs for the slicing tool
108        if (dumpToFile) {
109            s << "output" << exampleNo << " = {\n";
110            printAll(s, test);
111            // all outputs are done
112            s << "}\n";
113        }
114
115        // Filter out don't cares
116        MixedTyped filteredGolden = filter(golden, isIgnored);
117        MixedTyped filteredTest = filter(test, isIgnored);
118        // We want "close-enough" results for float
119        compare(filteredGolden, filteredTest);
120        exampleNo++;
121    }
122}
123
124};  // namespace generated_tests
125
126using namespace android::nn::wrapper;
127
128// Mixed-typed examples
129typedef generated_tests::MixedTypedExampleType MixedTypedExample;
130
131class GeneratedTests : public ::testing::Test {
132   protected:
133    virtual void SetUp() {}
134};
135
136// Testcases generated from runtime/test/specs/*.mod.py
137using namespace generated_tests;
138#include "generated/all_generated_tests.cpp"
139// End of testcases generated from runtime/test/specs/*.mod.py
140