// TestGenerated.cpp revision 1fe4cea7a58045eb715cade0535220b5fb7f658d
1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Top level driver for models and examples generated by test_generator.py 18 19#include "NeuralNetworksWrapper.h" 20#include "TestHarness.h" 21 22//#include <android-base/logging.h> 23#include <gtest/gtest.h> 24#include <cassert> 25#include <cmath> 26#include <iostream> 27#include <map> 28 29namespace generated_tests { 30using namespace android::nn::wrapper; 31 32template <typename T> 33class Example { 34 public: 35 typedef T ElementType; 36 typedef std::pair<std::map<int, std::vector<T>>, 37 std::map<int, std::vector<T>>> 38 ExampleType; 39 40 static bool Execute(std::function<void(Model*)> create_model, 41 std::vector<ExampleType>& examples, 42 std::function<bool(const T, const T)> compare) { 43 Model model; 44 create_model(&model); 45 model.finish(); 46 47 int example_no = 1; 48 bool error = false; 49 for (auto& example : examples) { 50 Compilation compilation(&model); 51 compilation.finish(); 52 Execution execution(&compilation); 53 54 // Go through all inputs 55 for (auto& i : example.first) { 56 std::vector<T>& input = i.second; 57 execution.setInput(i.first, (const void*)input.data(), 58 input.size() * sizeof(T)); 59 } 60 61 std::map<int, std::vector<T>> test_outputs; 62 63 assert(example.second.size() == 1); 64 int output_no = 0; 65 for (auto& i : example.second) { 66 std::vector<T>& output = i.second; 67 
test_outputs[i.first].resize(output.size()); 68 std::vector<T>& test_output = test_outputs[i.first]; 69 execution.setOutput(output_no++, (void*)test_output.data(), 70 test_output.size() * sizeof(T)); 71 } 72 Result r = execution.compute(); 73 if (r != Result::NO_ERROR) 74 std::cerr << "Execution was not completed normally\n"; 75 bool mismatch = false; 76 for (auto& i : example.second) { 77 const std::vector<T>& test = test_outputs[i.first]; 78 const std::vector<T>& golden = i.second; 79 for (unsigned i = 0; i < golden.size(); i++) { 80 if (compare(golden[i], test[i])) { 81 std::cerr << " output[" << i << "] = " << (float)test[i] 82 << " (should be " << (float)golden[i] 83 << ")\n"; 84 error = error || true; 85 mismatch = mismatch || true; 86 } 87 } 88 } 89 if (mismatch) { 90 std::cerr << "Example: " << example_no++; 91 std::cerr << " failed\n"; 92 } 93 } 94 return error; 95 } 96 97 // Test driver for those generated from ml/nn/runtime/test/spec 98 static void Execute(std::function<void(Model*)> create_model, 99 std::function<bool(int)> is_ignored, 100 std::vector<MixedTypedExampleType>& examples) { 101 Model model; 102 create_model(&model); 103 model.finish(); 104 105 int example_no = 1; 106 for (auto& example : examples) { 107 SCOPED_TRACE(example_no++); 108 MixedTyped inputs = example.first; 109 const MixedTyped &golden = example.second; 110 111 Compilation compilation(&model); 112 compilation.finish(); 113 Execution execution(&compilation); 114 115 // Go through all ty-typed inputs 116 for_all(inputs, [&execution](int idx, auto p, auto s) { 117 ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, p, s)); 118 }); 119 120 MixedTyped test; 121 // Go through all typed outputs 122 resize_accordingly<float>(golden, test); 123 resize_accordingly<int32_t>(golden, test); 124 resize_accordingly<uint8_t>(golden, test); 125 for_all(test, [&execution](int idx, void* p, auto s) { 126 ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, p, s)); 127 }); 128 129 Result r = 
execution.compute(); 130 ASSERT_EQ(Result::NO_ERROR, r); 131 // Filter out don't cares 132 MixedTyped filtered_golden; 133 MixedTyped filtered_test; 134 filter<float>(golden, &filtered_golden, is_ignored); 135 filter<float>(test, &filtered_test, is_ignored); 136 filter<int32_t>(golden, &filtered_golden, is_ignored); 137 filter<int32_t>(test, &filtered_test, is_ignored); 138 filter<uint8_t>(golden, &filtered_golden, is_ignored); 139 filter<uint8_t>(test, &filtered_test, is_ignored); 140#define USE_EXPECT_FLOAT_EQ 1 141#ifdef USE_EXPECT_FLOAT_EQ 142 // We want "close-enough" results for float 143 for_each<float>(filtered_golden, 144 [&filtered_test](int index, auto& m) { 145 auto& test_float_operands = 146 std::get<Float32Operands>(filtered_test); 147 auto& test_float = test_float_operands[index]; 148 for (unsigned int i = 0; i < m.size(); i++) { 149 SCOPED_TRACE(i); 150 EXPECT_NEAR(m[i], test_float[i], 1.e-5); 151 } 152 }); 153#else // Use EXPECT_EQ instead; nicer error reporting 154 EXPECT_EQ(std::get<Float32Operands>(filtered_golden), 155 std::get<Float32Operands>(filtered_test)); 156#endif 157 EXPECT_EQ(std::get<Int32Operands>(filtered_golden), 158 std::get<Int32Operands>(filtered_test)); 159 EXPECT_EQ(std::get<Quant8Operands>(filtered_golden), 160 std::get<Quant8Operands>(filtered_test)); 161 } 162 } 163}; 164}; // namespace generated_tests 165 166using namespace android::nn::wrapper; 167// Float32 examples 168typedef generated_tests::Example<float>::ExampleType Example; 169// Mixed-typed examples 170typedef generated_tests::MixedTypedExampleType MixedTypedExample; 171 172void Execute(std::function<void(Model*)> create_model, 173 std::function<bool(int)> is_ignored, 174 std::vector<MixedTypedExample>& examples) { 175 generated_tests::Example<float>::Execute(create_model, 176 is_ignored, examples); 177} 178 179class GeneratedTests : public ::testing::Test { 180 protected: 181 virtual void SetUp() { 182 // For detailed logs, uncomment this line: 183 // 
SetMinimumLogSeverity(android::base::VERBOSE); 184 } 185}; 186 187// Testcases generated from runtime/test/specs/*.mod.py 188#include "generated/all_generated_tests.cpp" 189// End of testcases generated from runtime/test/specs/*.mod.py 190 191// Below are testcases geneated from TFLite testcases. 192namespace conv_1_h3_w2_SAME { 193std::vector<Example> examples = { 194// Converted examples 195#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc" 196}; 197// Generated model constructor 198#include "generated/models/conv_1_h3_w2_SAME.model.cpp" 199} // namespace conv_1_h3_w2_SAME 200 201namespace conv_1_h3_w2_VALID { 202std::vector<Example> examples = { 203// Converted examples 204#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc" 205}; 206// Generated model constructor 207#include "generated/models/conv_1_h3_w2_VALID.model.cpp" 208} // namespace conv_1_h3_w2_VALID 209 210namespace conv_3_h3_w2_SAME { 211std::vector<Example> examples = { 212// Converted examples 213#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc" 214}; 215// Generated model constructor 216#include "generated/models/conv_3_h3_w2_SAME.model.cpp" 217} // namespace conv_3_h3_w2_SAME 218 219namespace conv_3_h3_w2_VALID { 220std::vector<Example> examples = { 221// Converted examples 222#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc" 223}; 224// Generated model constructor 225#include "generated/models/conv_3_h3_w2_VALID.model.cpp" 226} // namespace conv_3_h3_w2_VALID 227 228namespace depthwise_conv { 229std::vector<Example> examples = { 230// Converted examples 231#include "generated/examples/depthwise_conv_tests.example.cc" 232}; 233// Generated model constructor 234#include "generated/models/depthwise_conv.model.cpp" 235} // namespace depthwise_conv 236 237namespace mobilenet { 238std::vector<Example> examples = { 239// Converted examples 240#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc" 241}; 242// Generated 
model constructor 243#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp" 244} // namespace mobilenet 245 246namespace { 247bool Execute(std::function<void(Model*)> create_model, 248 std::vector<Example>& examples) { 249 return generated_tests::Example<float>::Execute( 250 create_model, examples, [](float golden, float test) { 251 return std::fabs(golden - test) > 1.5e-5f; 252 }); 253} 254} // namespace 255 256TEST_F(GeneratedTests, conv_1_h3_w2_SAME) { 257 ASSERT_EQ( 258 Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples), 259 0); 260} 261 262TEST_F(GeneratedTests, conv_1_h3_w2_VALID) { 263 ASSERT_EQ( 264 Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples), 265 0); 266} 267 268TEST_F(GeneratedTests, conv_3_h3_w2_SAME) { 269 ASSERT_EQ( 270 Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples), 271 0); 272} 273 274TEST_F(GeneratedTests, conv_3_h3_w2_VALID) { 275 ASSERT_EQ( 276 Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples), 277 0); 278} 279 280TEST_F(GeneratedTests, depthwise_conv) { 281 ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples), 282 0); 283} 284 285TEST_F(GeneratedTests, mobilenet) { 286 ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0); 287} 288