1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Top level driver for models and examples generated by test_generator.py 18 19#include "Bridge.h" 20#include "NeuralNetworksWrapper.h" 21#include "TestHarness.h" 22 23#include <gtest/gtest.h> 24#include <cassert> 25#include <cmath> 26#include <fstream> 27#include <iostream> 28#include <map> 29 30// Uncomment the following line to generate DOT graphs. 31// 32// #define GRAPH GRAPH 33 34namespace generated_tests { 35using namespace android::nn::wrapper; 36using namespace test_helper; 37 38void graphDump([[maybe_unused]] const char* name, [[maybe_unused]] const Model& model) { 39#ifdef GRAPH 40 ::android::nn::bridge_tests::graphDump( 41 name, 42 reinterpret_cast<const ::android::nn::ModelBuilder*>(model.getHandle())); 43#endif 44} 45 46template <typename T> 47static void print(std::ostream& os, const MixedTyped& test) { 48 // dump T-typed inputs 49 for_each<T>(test, [&os](int idx, const std::vector<T>& f) { 50 os << " aliased_output" << idx << ": ["; 51 for (size_t i = 0; i < f.size(); ++i) { 52 os << (i == 0 ? "" : ", ") << +f[i]; 53 } 54 os << "],\n"; 55 }); 56} 57 58static void printAll(std::ostream& os, const MixedTyped& test) { 59 print<float>(os, test); 60 print<int32_t>(os, test); 61 print<uint8_t>(os, test); 62} 63 64// Test driver for those generated from ml/nn/runtime/test/spec 65static void execute(std::function<void(Model*)> createModel, 66 std::function<bool(int)> isIgnored, 67 std::vector<MixedTypedExampleType>& examples, 68 std::string dumpFile = "") { 69 Model model; 70 createModel(&model); 71 model.finish(); 72 graphDump("", model); 73 bool dumpToFile = !dumpFile.empty(); 74 75 std::ofstream s; 76 if (dumpToFile) { 77 s.open(dumpFile, std::ofstream::trunc); 78 ASSERT_TRUE(s.is_open()); 79 } 80 81 int exampleNo = 0; 82 Compilation compilation(&model); 83 compilation.finish(); 84 85 // If in relaxed mode, set the error range to be 5ULP of FP16. 86 float fpRange = !model.isRelaxed() ? 1e-5f : 5.0f * 0.0009765625f; 87 for (auto& example : examples) { 88 SCOPED_TRACE(exampleNo); 89 // TODO: We leave it as a copy here. 90 // Should verify if the input gets modified by the test later. 91 MixedTyped inputs = example.first; 92 const MixedTyped& golden = example.second; 93 94 Execution execution(&compilation); 95 96 // Set all inputs 97 for_all(inputs, [&execution](int idx, const void* p, size_t s) { 98 const void* buffer = s == 0 ? nullptr : p; 99 ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, buffer, s)); 100 }); 101 102 MixedTyped test; 103 // Go through all typed outputs 104 resize_accordingly(golden, test); 105 for_all(test, [&execution](int idx, void* p, size_t s) { 106 void* buffer = s == 0 ? nullptr : p; 107 ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, buffer, s)); 108 }); 109 110 Result r = execution.compute(); 111 ASSERT_EQ(Result::NO_ERROR, r); 112 113 // Dump all outputs for the slicing tool 114 if (dumpToFile) { 115 s << "output" << exampleNo << " = {\n"; 116 printAll(s, test); 117 // all outputs are done 118 s << "}\n"; 119 } 120 121 // Filter out don't cares 122 MixedTyped filteredGolden = filter(golden, isIgnored); 123 MixedTyped filteredTest = filter(test, isIgnored); 124 // We want "close-enough" results for float 125 126 compare(filteredGolden, filteredTest, fpRange); 127 exampleNo++; 128 } 129} 130 131}; // namespace generated_tests 132 133using namespace android::nn::wrapper; 134 135// Mixed-typed examples 136typedef test_helper::MixedTypedExampleType MixedTypedExample; 137 138class GeneratedTests : public ::testing::Test { 139protected: 140 virtual void SetUp() {} 141}; 142 143// Testcases generated from runtime/test/specs/*.mod.py 144using namespace test_helper; 145using namespace generated_tests; 146#include "generated/all_generated_tests.cpp" 147// End of testcases generated from runtime/test/specs/*.mod.py 148