TestGenerated.cpp revision 544739620cd7f37d40524d2407c92042e485c73f
1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Top level driver for models and examples generated by test_generator.py 18 19#include "NeuralNetworksWrapper.h" 20#include "TestHarness.h" 21 22#include <gtest/gtest.h> 23#include <cassert> 24#include <cmath> 25#include <functional> 26#include <iostream> 27#include <map> 28 29namespace generated_tests { 30using namespace android::nn::wrapper; 31 32template <typename T> 33class Example { 34 public: 35 typedef T ElementType; 36 typedef std::pair<std::map<int, std::vector<T>>, 37 std::map<int, std::vector<T>>> 38 ExampleType; 39 40 static bool Execute(std::function<void(Model*)> create_model, 41 std::vector<ExampleType>& examples, 42 std::function<bool(const T, const T)> compare) { 43 Model model; 44 create_model(&model); 45 model.finish(); 46 47 int example_no = 1; 48 bool error = false; 49 for (auto& example : examples) { 50 Compilation compilation(&model); 51 compilation.compile(); 52 Request request(&compilation); 53 54 // Go through all inputs 55 for (auto& i : example.first) { 56 std::vector<T>& input = i.second; 57 request.setInput(i.first, (const void*)input.data(), 58 input.size() * sizeof(T)); 59 } 60 61 std::map<int, std::vector<T>> test_outputs; 62 63 assert(example.second.size() == 1); 64 int output_no = 0; 65 for (auto& i : example.second) { 66 std::vector<T>& output = i.second; 67 test_outputs[i.first].resize(output.size()); 68 std::vector<T>& test_output = test_outputs[i.first]; 69 request.setOutput(output_no++, (void*)test_output.data(), 70 test_output.size() * sizeof(T)); 71 } 72 Result r = request.compute(); 73 if (r != Result::NO_ERROR) 74 std::cerr << "Request was not completed normally\n"; 75 bool mismatch = false; 76 for (auto& i : example.second) { 77 const std::vector<T>& test = test_outputs[i.first]; 78 const std::vector<T>& golden = i.second; 79 for (unsigned i = 0; i < golden.size(); i++) { 80 if (compare(golden[i], test[i])) { 81 std::cerr << " output[" << i << "] = " << (float)test[i] 82 << " (should be " << (float)golden[i] 83 << ")\n"; 84 error = error || true; 85 mismatch = mismatch || true; 86 } 87 } 88 } 89 if (mismatch) { 90 std::cerr << "Example: " << example_no++; 91 std::cerr << " failed\n"; 92 } 93 } 94 return error; 95 } 96 97 // Test driver for those generated from ml/nn/runtime/test/spec 98 static void Execute(std::function<void(Model*)> create_model, 99 std::vector<MixedTypedExampleType>& examples) { 100 Model model; 101 create_model(&model); 102 model.finish(); 103 104 int example_no = 1; 105 for (auto& example : examples) { 106 SCOPED_TRACE(example_no++); 107 108 MixedTyped& inputs = example.first; 109 MixedTyped& golden = example.second; 110 111 Compilation compilation(&model); 112 compilation.compile(); 113 Request request(&compilation); 114 115 // Go through all ty-typed inputs 116#define SET_TYPED_TENSOR_INPUT(ty) \ 117 for (auto& i : std::get<std::map<int, std::vector<ty>>>(inputs)) { \ 118 request.setInput(i.first, (const void*)i.second.data(), \ 119 i.second.size() * sizeof(ty)); \ 120 } 121 122 SET_TYPED_TENSOR_INPUT(float); 123 SET_TYPED_TENSOR_INPUT(int32_t); 124 SET_TYPED_TENSOR_INPUT(uint8_t); 125#undef SET_TYPED_TENSOR_INPUT 126 127 MixedTyped test; 128 // Go through all typed outputs 129#define SET_TYPED_OUTPUT(ty) \ 130 auto& golden_##ty = std::get<std::map<int, std::vector<ty>>>(golden); \ 131 auto& test_##ty = std::get<std::map<int, std::vector<ty>>>(test); \ 132 for (auto& i : golden_##ty) { \ 133 int idx = i.first; \ 134 auto& golden_output = i.second; \ 135 test_##ty[idx].resize(golden_output.size()); \ 136 request.setOutput(idx, (void*)test_##ty[idx].data(), \ 137 test_##ty[idx].size() * sizeof(ty)); \ 138 } 139 SET_TYPED_OUTPUT(float); 140 SET_TYPED_OUTPUT(int32_t); 141 SET_TYPED_OUTPUT(uint8_t); 142#undef SET_TYPED_OUTPUT 143 144 Result r = request.compute(); 145 ASSERT_EQ(Result::NO_ERROR, r); 146 147 EXPECT_EQ(golden_float, test_float); 148 EXPECT_EQ(golden_int32_t, test_int32_t); 149 EXPECT_EQ(golden_uint8_t, test_uint8_t); 150 } 151 } 152}; 153}; // namespace generated_tests 154 155using namespace android::nn::wrapper; 156// Float32 examples 157typedef generated_tests::Example<float>::ExampleType Example; 158// Mixed-typed examples 159typedef generated_tests::MixedTypedExampleType MixedTypedExample; 160 161void Execute(std::function<void(Model*)> create_model, 162 std::vector<MixedTypedExample>& examples) { 163 generated_tests::Example<float>::Execute(create_model, examples); 164} 165 166class GeneratedTests : public ::testing::Test { 167 protected: 168 virtual void SetUp() { 169 ASSERT_EQ(android::nn::wrapper::Initialize(), 170 android::nn::wrapper::Result::NO_ERROR); 171 } 172 173 virtual void TearDown() { android::nn::wrapper::Shutdown(); } 174}; 175 176// Testcases generated from runtime/test/specs/*.mod.py 177#include "generated/all_generated_tests.cpp" 178// End of testcases generated from runtime/test/specs/*.mod.py 179 180// Below are testcases geneated from TFLite testcases. 181namespace conv_1_h3_w2_SAME { 182std::vector<Example> examples = { 183// Converted examples 184#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc" 185}; 186// Generated model constructor 187#include "generated/models/conv_1_h3_w2_SAME.model.cpp" 188} // namespace conv_1_h3_w2_SAME 189 190namespace conv_1_h3_w2_VALID { 191std::vector<Example> examples = { 192// Converted examples 193#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc" 194}; 195// Generated model constructor 196#include "generated/models/conv_1_h3_w2_VALID.model.cpp" 197} // namespace conv_1_h3_w2_VALID 198 199namespace conv_3_h3_w2_SAME { 200std::vector<Example> examples = { 201// Converted examples 202#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc" 203}; 204// Generated model constructor 205#include "generated/models/conv_3_h3_w2_SAME.model.cpp" 206} // namespace conv_3_h3_w2_SAME 207 208namespace conv_3_h3_w2_VALID { 209std::vector<Example> examples = { 210// Converted examples 211#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc" 212}; 213// Generated model constructor 214#include "generated/models/conv_3_h3_w2_VALID.model.cpp" 215} // namespace conv_3_h3_w2_VALID 216 217namespace depthwise_conv { 218std::vector<Example> examples = { 219// Converted examples 220#include "generated/examples/depthwise_conv_tests.example.cc" 221}; 222// Generated model constructor 223#include "generated/models/depthwise_conv.model.cpp" 224} // namespace depthwise_conv 225 226namespace mobilenet { 227std::vector<Example> examples = { 228// Converted examples 229#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc" 230}; 231// Generated model constructor 232#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp" 233} // namespace mobilenet 234 235namespace { 236bool Execute(std::function<void(Model*)> create_model, 237 std::vector<Example>& examples) { 238 return generated_tests::Example<float>::Execute( 239 create_model, examples, [](float golden, float test) { 240 return std::fabs(golden - test) > 1.5e-5f; 241 }); 242} 243} // namespace 244 245 246TEST_F(GeneratedTests, conv_1_h3_w2_SAME) { 247 ASSERT_EQ( 248 Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples), 249 0); 250} 251 252TEST_F(GeneratedTests, conv_1_h3_w2_VALID) { 253 ASSERT_EQ( 254 Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples), 255 0); 256} 257 258TEST_F(GeneratedTests, conv_3_h3_w2_SAME) { 259 ASSERT_EQ( 260 Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples), 261 0); 262} 263 264TEST_F(GeneratedTests, conv_3_h3_w2_VALID) { 265 ASSERT_EQ( 266 Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples), 267 0); 268} 269 270TEST_F(GeneratedTests, depthwise_conv) { 271 ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples), 272 0); 273} 274 275TEST_F(GeneratedTests, mobilenet) { 276 ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0); 277} 278