// TestGenerated.cpp revision 663155d58ca1e1eb42e01495e236258e0c00d40f
1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Top level driver for models and examples generated by test_generator.py 18 19#include "Bridge.h" 20#include "NeuralNetworksWrapper.h" 21#include "TestHarness.h" 22 23#include <gtest/gtest.h> 24#include <cassert> 25#include <cmath> 26#include <fstream> 27#include <iostream> 28#include <map> 29 30// Uncomment the following line to generate DOT graphs. 31// 32// #define GRAPH GRAPH 33 34namespace generated_tests { 35using namespace android::nn::wrapper; 36 37void graphDump([[maybe_unused]] const char* name, [[maybe_unused]] const Model& model) { 38#ifdef GRAPH 39 ::android::nn::bridge_tests::graphDump( 40 name, 41 reinterpret_cast<const ::android::nn::ModelBuilder*>(model.getHandle())); 42#endif 43} 44 45template <typename T> 46static void print(std::ostream& os, const MixedTyped& test) { 47 // dump T-typed inputs 48 for_each<T>(test, [&os](int idx, const std::vector<T>& f) { 49 os << " aliased_output" << idx << ": ["; 50 for (size_t i = 0; i < f.size(); ++i) { 51 os << (i == 0 ? 
"" : ", ") << +f[i]; 52 } 53 os << "],\n"; 54 }); 55} 56 57static void printAll(std::ostream& os, const MixedTyped& test) { 58 print<float>(os, test); 59 print<int32_t>(os, test); 60 print<uint8_t>(os, test); 61} 62 63// Test driver for those generated from ml/nn/runtime/test/spec 64static void execute(std::function<void(Model*)> createModel, 65 std::function<bool(int)> isIgnored, 66 std::vector<MixedTypedExampleType>& examples, 67 std::string dumpFile = "") { 68 Model model; 69 createModel(&model); 70 model.finish(); 71 graphDump("", model); 72 bool dumpToFile = !dumpFile.empty(); 73 74 std::ofstream s; 75 if (dumpToFile) { 76 s.open(dumpFile, std::ofstream::trunc); 77 ASSERT_TRUE(s.is_open()); 78 } 79 80 int exampleNo = 0; 81 Compilation compilation(&model); 82 compilation.finish(); 83 for (auto& example : examples) { 84 SCOPED_TRACE(exampleNo); 85 // TODO: We leave it as a copy here. 86 // Should verify if the input gets modified by the test later. 87 MixedTyped inputs = example.first; 88 const MixedTyped& golden = example.second; 89 90 Execution execution(&compilation); 91 92 // Set all inputs 93 for_all(inputs, [&execution](int idx, const void* p, size_t s) { 94 ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, p, s)); 95 }); 96 97 MixedTyped test; 98 // Go through all typed outputs 99 resize_accordingly(golden, test); 100 for_all(test, [&execution](int idx, void* p, size_t s) { 101 ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, p, s)); 102 }); 103 104 Result r = execution.compute(); 105 ASSERT_EQ(Result::NO_ERROR, r); 106 107 // Dump all outputs for the slicing tool 108 if (dumpToFile) { 109 s << "output" << exampleNo << " = {\n"; 110 printAll(s, test); 111 // all outputs are done 112 s << "}\n"; 113 } 114 115 // Filter out don't cares 116 MixedTyped filteredGolden = filter(golden, isIgnored); 117 MixedTyped filteredTest = filter(test, isIgnored); 118 // We want "close-enough" results for float 119 compare(filteredGolden, filteredTest); 120 
exampleNo++; 121 } 122} 123 124}; // namespace generated_tests 125 126using namespace android::nn::wrapper; 127 128// Mixed-typed examples 129typedef generated_tests::MixedTypedExampleType MixedTypedExample; 130 131class GeneratedTests : public ::testing::Test { 132 protected: 133 virtual void SetUp() {} 134}; 135 136// Testcases generated from runtime/test/specs/*.mod.py 137using namespace generated_tests; 138#include "generated/all_generated_tests.cpp" 139// End of testcases generated from runtime/test/specs/*.mod.py 140