// TestGenerated.cpp revision 62cc2758c1c2d303861e209f26bddcf4d7564b73
1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Top level driver for models and examples generated by test_generator.py 18 19#include "NeuralNetworksWrapper.h" 20#include "TestHarness.h" 21 22#include <gtest/gtest.h> 23#include <cassert> 24#include <cmath> 25#include <iostream> 26#include <map> 27 28namespace generated_tests { 29using namespace android::nn::wrapper; 30 31template <typename T> 32class Example { 33 public: 34 typedef T ElementType; 35 typedef std::pair<std::map<int, std::vector<T>>, 36 std::map<int, std::vector<T>>> 37 ExampleType; 38 39 static bool Execute(std::function<void(Model*)> create_model, 40 std::vector<ExampleType>& examples, 41 std::function<bool(const T, const T)> compare) { 42 Model model; 43 create_model(&model); 44 model.finish(); 45 46 int example_no = 1; 47 bool error = false; 48 for (auto& example : examples) { 49 Compilation compilation(&model); 50 compilation.finish(); 51 Execution execution(&compilation); 52 53 // Go through all inputs 54 for (auto& i : example.first) { 55 std::vector<T>& input = i.second; 56 // We interpret an empty vector as an optional argument 57 // that has been omitted. 
58 if (input.size() == 0) { 59 execution.setInput(i.first, nullptr, 0); 60 } else { 61 execution.setInput(i.first, (const void*)input.data(), 62 input.size() * sizeof(T)); 63 } 64 } 65 66 std::map<int, std::vector<T>> test_outputs; 67 68 assert(example.second.size() == 1); 69 int output_no = 0; 70 for (auto& i : example.second) { 71 std::vector<T>& output = i.second; 72 test_outputs[i.first].resize(output.size()); 73 std::vector<T>& test_output = test_outputs[i.first]; 74 execution.setOutput(output_no++, (void*)test_output.data(), 75 test_output.size() * sizeof(T)); 76 } 77 Result r = execution.compute(); 78 if (r != Result::NO_ERROR) 79 std::cerr << "Execution was not completed normally\n"; 80 bool mismatch = false; 81 for (auto& i : example.second) { 82 const std::vector<T>& test = test_outputs[i.first]; 83 const std::vector<T>& golden = i.second; 84 for (unsigned i = 0; i < golden.size(); i++) { 85 if (compare(golden[i], test[i])) { 86 std::cerr << " output[" << i << "] = " << (float)test[i] 87 << " (should be " << (float)golden[i] 88 << ")\n"; 89 error = error || true; 90 mismatch = mismatch || true; 91 } 92 } 93 } 94 if (mismatch) { 95 std::cerr << "Example: " << example_no++; 96 std::cerr << " failed\n"; 97 } 98 } 99 return error; 100 } 101 102 // Test driver for those generated from ml/nn/runtime/test/spec 103 static void Execute(std::function<void(Model*)> create_model, 104 std::function<bool(int)> is_ignored, 105 std::vector<MixedTypedExampleType>& examples) { 106 Model model; 107 create_model(&model); 108 model.finish(); 109 110 int example_no = 1; 111 for (auto& example : examples) { 112 SCOPED_TRACE(example_no++); 113 MixedTyped inputs = example.first; 114 const MixedTyped& golden = example.second; 115 116 Compilation compilation(&model); 117 compilation.finish(); 118 Execution execution(&compilation); 119 120 // Set all inputs 121 for_all(inputs, [&execution](int idx, const void* p, size_t s) { 122 ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, p, 
s)); 123 }); 124 125 MixedTyped test; 126 // Go through all typed outputs 127 resize_accordingly(golden, test); 128 for_all(test, [&execution](int idx, void* p, size_t s) { 129 ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, p, s)); 130 }); 131 132 Result r = execution.compute(); 133 ASSERT_EQ(Result::NO_ERROR, r); 134 // Filter out don't cares 135 MixedTyped filtered_golden; 136 MixedTyped filtered_test; 137 filter(golden, &filtered_golden, is_ignored); 138 filter(test, &filtered_test, is_ignored); 139 // We want "close-enough" results for float 140 compare(filtered_golden, filtered_test); 141 } 142 } 143}; 144}; // namespace generated_tests 145 146using namespace android::nn::wrapper; 147// Float32 examples 148typedef generated_tests::Example<float>::ExampleType Example; 149// Mixed-typed examples 150typedef generated_tests::MixedTypedExampleType MixedTypedExample; 151 152void Execute(std::function<void(Model*)> create_model, 153 std::function<bool(int)> is_ignored, 154 std::vector<MixedTypedExample>& examples) { 155 generated_tests::Example<float>::Execute(create_model, is_ignored, 156 examples); 157} 158 159class GeneratedTests : public ::testing::Test { 160 protected: 161 virtual void SetUp() { 162 // For detailed logs, uncomment this line: 163 // SetMinimumLogSeverity(android::base::VERBOSE); 164 } 165}; 166 167// Testcases generated from runtime/test/specs/*.mod.py 168#include "generated/all_generated_tests.cpp" 169// End of testcases generated from runtime/test/specs/*.mod.py 170 171// Below are testcases geneated from TFLite testcases. 
172namespace conv_1_h3_w2_SAME { 173std::vector<Example> examples = { 174// Converted examples 175#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc" 176}; 177// Generated model constructor 178#include "generated/models/conv_1_h3_w2_SAME.model.cpp" 179} // namespace conv_1_h3_w2_SAME 180 181namespace conv_1_h3_w2_VALID { 182std::vector<Example> examples = { 183// Converted examples 184#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc" 185}; 186// Generated model constructor 187#include "generated/models/conv_1_h3_w2_VALID.model.cpp" 188} // namespace conv_1_h3_w2_VALID 189 190namespace conv_3_h3_w2_SAME { 191std::vector<Example> examples = { 192// Converted examples 193#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc" 194}; 195// Generated model constructor 196#include "generated/models/conv_3_h3_w2_SAME.model.cpp" 197} // namespace conv_3_h3_w2_SAME 198 199namespace conv_3_h3_w2_VALID { 200std::vector<Example> examples = { 201// Converted examples 202#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc" 203}; 204// Generated model constructor 205#include "generated/models/conv_3_h3_w2_VALID.model.cpp" 206} // namespace conv_3_h3_w2_VALID 207 208namespace depthwise_conv { 209std::vector<Example> examples = { 210// Converted examples 211#include "generated/examples/depthwise_conv_tests.example.cc" 212}; 213// Generated model constructor 214#include "generated/models/depthwise_conv.model.cpp" 215} // namespace depthwise_conv 216 217namespace mobilenet { 218std::vector<Example> examples = { 219// Converted examples 220#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc" 221}; 222// Generated model constructor 223#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp" 224} // namespace mobilenet 225 226namespace { 227bool Execute(std::function<void(Model*)> create_model, 228 std::vector<Example>& examples) { 229 return generated_tests::Example<float>::Execute( 230 
create_model, examples, [](float golden, float test) { 231 return std::fabs(golden - test) > 1.5e-5f; 232 }); 233} 234} // namespace 235 236TEST_F(GeneratedTests, conv_1_h3_w2_SAME) { 237 ASSERT_EQ( 238 Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples), 239 0); 240} 241 242TEST_F(GeneratedTests, conv_1_h3_w2_VALID) { 243 ASSERT_EQ( 244 Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples), 245 0); 246} 247 248TEST_F(GeneratedTests, conv_3_h3_w2_SAME) { 249 ASSERT_EQ( 250 Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples), 251 0); 252} 253 254TEST_F(GeneratedTests, conv_3_h3_w2_VALID) { 255 ASSERT_EQ( 256 Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples), 257 0); 258} 259 260TEST_F(GeneratedTests, depthwise_conv) { 261 ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples), 262 0); 263} 264 265TEST_F(GeneratedTests, mobilenet) { 266 ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0); 267} 268