TestGenerated.cpp revision 5d58071b9afb2162bcd53525ce1c8610c5ca2cd8
1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Top level driver for models and examples converted from TFLite tests 18 19#include "NeuralNetworksWrapper.h" 20 21#include <gtest/gtest.h> 22#include <cassert> 23#include <cmath> 24#include <iostream> 25#include <map> 26 27typedef std::pair<std::map<int, std::vector<float>>, 28 std::map<int, std::vector<float>>> 29 Example; 30 31using namespace android::nn::wrapper; 32 33namespace add { 34std::vector<Example> examples = { 35// Generated add 36#include "generated/examples/add_tests.example.cc" 37}; 38// Generated model constructor 39#include "generated/models/add.model.cpp" 40} // add 41 42namespace conv_1_h3_w2_SAME { 43std::vector<Example> examples = { 44// Converted examples 45#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc" 46}; 47// Generated model constructor 48#include "generated/models/conv_1_h3_w2_SAME.model.cpp" 49} // namespace conv_1_h3_w2_SAME 50 51namespace conv_1_h3_w2_VALID { 52std::vector<Example> examples = { 53// Converted examples 54#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc" 55}; 56// Generated model constructor 57#include "generated/models/conv_1_h3_w2_VALID.model.cpp" 58} // namespace conv_1_h3_w2_VALID 59 60namespace conv_3_h3_w2_SAME { 61std::vector<Example> examples = { 62// Converted examples 63#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc" 64}; 65// Generated model constructor 66#include "generated/models/conv_3_h3_w2_SAME.model.cpp" 67} // namespace conv_3_h3_w2_SAME 68 69namespace conv_3_h3_w2_VALID { 70std::vector<Example> examples = { 71// Converted examples 72#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc" 73}; 74// Generated model constructor 75#include "generated/models/conv_3_h3_w2_VALID.model.cpp" 76} // namespace conv_3_h3_w2_VALID 77 78namespace depthwise_conv { 79std::vector<Example> examples = { 80// Converted examples 81#include "generated/examples/depthwise_conv_tests.example.cc" 82}; 83// Generated model constructor 84#include "generated/models/depthwise_conv.model.cpp" 85} // namespace depthwise_conv 86 87namespace avg_pool_float { 88std::vector<Example> examples = { 89// Generated avg_pool float 90#include "generated/examples/avg_pool_float_tests.example.cc" 91}; 92// Generated model constructor 93#include "generated/models/avg_pool_float.model.cpp" 94} // avg_pool_float 95 96namespace max_pool_float { 97std::vector<Example> examples = { 98// Generated max_pool float 99#include "generated/examples/max_pool_float_tests.example.cc" 100}; 101// Generated model constructor 102#include "generated/models/max_pool_float.model.cpp" 103} // max_pool_float 104 105namespace l2_pool_float { 106std::vector<Example> examples = { 107// Generated l2_pool float 108#include "generated/examples/l2_pool_float_tests.example.cc" 109}; 110// Generated model constructor 111#include "generated/models/l2_pool_float.model.cpp" 112} // l2_pool_float 113 114namespace relu_float { 115std::vector<Example> examples = { 116// Generated relu float 117#include "generated/examples/relu_float_tests.example.cc" 118}; 119// Generated model constructor 120#include "generated/models/relu_float.model.cpp" 121} // relu_float 122 123namespace relu1_float { 124std::vector<Example> examples = { 125// Generated relu1 float 126#include "generated/examples/relu1_float_tests.example.cc" 127}; 128// Generated model constructor 129#include "generated/models/relu1_float.model.cpp" 130} // relu1_float 131 132namespace relu6_float { 133std::vector<Example> examples = { 134// Generated relu6 float 135#include "generated/examples/relu6_float_tests.example.cc" 136}; 137// Generated model constructor 138#include "generated/models/relu6_float.model.cpp" 139} // relu6_float 140 141namespace mobilenet { 142std::vector<Example> examples = { 143// Converted examples 144#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc" 145}; 146// Generated model constructor 147#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp" 148} // namespace mobilenet 149 150namespace { 151bool Execute(void (*create_model)(Model*), std::vector<Example>& examples) { 152 Model model; 153 create_model(&model); 154 155 int example_no = 1; 156 bool error = false; 157 158 for (auto& example : examples) { 159 Request request(&model); 160 161 // Go through all inputs 162 for (auto& i : example.first) { 163 std::vector<float>& input = i.second; 164 request.setInput(i.first, (const void*)input.data(), 165 input.size() * sizeof(float)); 166 } 167 168 std::map<int, std::vector<float>> test_outputs; 169 170 assert(example.second.size() == 1); 171 int output_no = 0; 172 for (auto& i : example.second) { 173 std::vector<float>& output = i.second; 174 test_outputs[i.first].resize(output.size()); 175 std::vector<float>& test_output = test_outputs[i.first]; 176 request.setOutput(output_no++, (void*)test_output.data(), 177 test_output.size() * sizeof(float)); 178 } 179 Result r = request.compute(); 180 if (r != Result::NO_ERROR) 181 std::cerr << "Request was not completed normally\n"; 182 bool mismatch = false; 183 for (auto& i : example.second) { 184 std::vector<float>& test = test_outputs[i.first]; 185 std::vector<float>& golden = i.second; 186 for (unsigned i = 0; i < golden.size(); i++) { 187 if (std::fabs(golden[i] - test[i]) > 1.5e-5f) { 188 std::cerr << " output[" << i << "] = " << test[i] 189 << " (should be " << golden[i] << ")\n"; 190 error = error || true; 191 mismatch = mismatch || true; 192 } 193 } 194 } 195 if (mismatch) { 196 std::cerr << "Example: " << example_no++; 197 std::cerr << " failed\n"; 198 } 199 } 200 return error; 201} 202 203class GeneratedTests : public ::testing::Test { 204 protected: 205 virtual void SetUp() { 206 ASSERT_EQ(android::nn::wrapper::Initialize(), 207 android::nn::wrapper::Result::NO_ERROR); 208 } 209 210 virtual void TearDown() { android::nn::wrapper::Shutdown(); } 211}; 212} // namespace 213 214TEST_F(GeneratedTests, add) { 215 ASSERT_EQ( 216 Execute(add::CreateModel, add::examples), 217 0); 218} 219 220TEST_F(GeneratedTests, conv_1_h3_w2_SAME) { 221 ASSERT_EQ( 222 Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples), 223 0); 224} 225 226TEST_F(GeneratedTests, conv_1_h3_w2_VALID) { 227 ASSERT_EQ( 228 Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples), 229 0); 230} 231 232TEST_F(GeneratedTests, conv_3_h3_w2_SAME) { 233 ASSERT_EQ( 234 Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples), 235 0); 236} 237 238TEST_F(GeneratedTests, conv_3_h3_w2_VALID) { 239 ASSERT_EQ( 240 Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples), 241 0); 242} 243 244TEST_F(GeneratedTests, depthwise_conv) { 245 ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples), 246 0); 247} 248 249TEST_F(GeneratedTests, avg_pool_float) { 250 ASSERT_EQ( 251 Execute(avg_pool_float::CreateModel, avg_pool_float::examples), 252 0); 253} 254 255TEST_F(GeneratedTests, max_pool_float) { 256 ASSERT_EQ( 257 Execute(max_pool_float::CreateModel, max_pool_float::examples), 258 0); 259} 260 261TEST_F(GeneratedTests, l2_pool_float) { 262 ASSERT_EQ( 263 Execute(l2_pool_float::CreateModel, l2_pool_float::examples), 264 0); 265} 266 267TEST_F(GeneratedTests, relu_float) { 268 ASSERT_EQ( 269 Execute(relu_float::CreateModel, relu_float::examples), 270 0); 271} 272 273TEST_F(GeneratedTests, relu1_float) { 274 ASSERT_EQ( 275 Execute(relu1_float::CreateModel, relu1_float::examples), 276 0); 277} 278 279TEST_F(GeneratedTests, relu6_float) { 280 ASSERT_EQ( 281 Execute(relu6_float::CreateModel, relu6_float::examples), 282 0); 283} 284 285TEST_F(GeneratedTests, mobilenet) { 286 ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0); 287} 288