TestGenerated.cpp revision b74d2837ab1687c1a4f913aa5f90a9838efe0add
1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Top level driver for models and examples generated by test_generator.py 18 19#include "Bridge.h" 20#include "NeuralNetworksWrapper.h" 21#include "TestHarness.h" 22 23#include <gtest/gtest.h> 24#include <cassert> 25#include <cmath> 26#include <fstream> 27#include <iostream> 28#include <map> 29 30// Uncomment the following line to generate DOT graphs. 31// 32// #define GRAPH GRAPH 33 34namespace generated_tests { 35using namespace android::nn::wrapper; 36 37void graphDump([[maybe_unused]] const char* name, [[maybe_unused]] const Model& model) { 38#ifdef GRAPH 39 ::android::nn::bridge_tests::graphDump( 40 name, 41 reinterpret_cast<const ::android::nn::ModelBuilder*>(model.getHandle())); 42#endif 43} 44 45template <typename T> 46static void print(std::ostream& os, const MixedTyped& test) { 47 // dump T-typed inputs 48 for_each<T>(test, [&os](int idx, const std::vector<T>& f) { 49 os << " aliased_output" << idx << ": ["; 50 for (size_t i = 0; i < f.size(); ++i) { 51 os << (i == 0 ? "" : ", ") << +f[i]; 52 } 53 os << "],\n"; 54 }); 55} 56 57static void printAll(std::ostream& os, const MixedTyped& test) { 58 print<float>(os, test); 59 print<int32_t>(os, test); 60 print<uint8_t>(os, test); 61} 62 63// Test driver for those generated from ml/nn/runtime/test/spec 64static void execute(std::function<void(Model*)> createModel, 65 std::function<bool(int)> isIgnored, 66 std::vector<MixedTypedExampleType>& examples, 67 std::string dumpFile = "") { 68 Model model; 69 createModel(&model); 70 model.finish(); 71 graphDump("", model); 72 bool dumpToFile = !dumpFile.empty(); 73 74 std::ofstream s; 75 if (dumpToFile) { 76 s.open(dumpFile, std::ofstream::trunc); 77 ASSERT_TRUE(s.is_open()); 78 } 79 80 int exampleNo = 0; 81 Compilation compilation(&model); 82 compilation.finish(); 83 84 // If in relaxed mode, set the error range to be 5ULP of FP16. 85 float fpRange = !model.isRelaxed() ? 1e-5f : 5.0f * 0.0009765625f; 86 for (auto& example : examples) { 87 SCOPED_TRACE(exampleNo); 88 // TODO: We leave it as a copy here. 89 // Should verify if the input gets modified by the test later. 90 MixedTyped inputs = example.first; 91 const MixedTyped& golden = example.second; 92 93 Execution execution(&compilation); 94 95 // Set all inputs 96 for_all(inputs, [&execution](int idx, const void* p, size_t s) { 97 ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, p, s)); 98 }); 99 100 MixedTyped test; 101 // Go through all typed outputs 102 resize_accordingly(golden, test); 103 for_all(test, [&execution](int idx, void* p, size_t s) { 104 ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, p, s)); 105 }); 106 107 Result r = execution.compute(); 108 ASSERT_EQ(Result::NO_ERROR, r); 109 110 // Dump all outputs for the slicing tool 111 if (dumpToFile) { 112 s << "output" << exampleNo << " = {\n"; 113 printAll(s, test); 114 // all outputs are done 115 s << "}\n"; 116 } 117 118 // Filter out don't cares 119 MixedTyped filteredGolden = filter(golden, isIgnored); 120 MixedTyped filteredTest = filter(test, isIgnored); 121 // We want "close-enough" results for float 122 123 compare(filteredGolden, filteredTest, fpRange); 124 exampleNo++; 125 } 126} 127 128}; // namespace generated_tests 129 130using namespace android::nn::wrapper; 131 132// Mixed-typed examples 133typedef generated_tests::MixedTypedExampleType MixedTypedExample; 134 135class GeneratedTests : public ::testing::Test { 136 protected: 137 virtual void SetUp() {} 138}; 139 140// Testcases generated from runtime/test/specs/*.mod.py 141using namespace generated_tests; 142#include "generated/all_generated_tests.cpp" 143// End of testcases generated from runtime/test/specs/*.mod.py 144