TestGenerated.cpp revision 5d58071b9afb2162bcd53525ce1c8610c5ca2cd8
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Top level driver for models and examples converted from TFLite tests
18
#include "NeuralNetworksWrapper.h"

#include <gtest/gtest.h>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <map>
#include <utility>
#include <vector>
26
// A single test case. `first` maps an input index (passed to
// Request::setInput) to that input's float data; `second` holds the golden
// float output data that the computed results are compared against.
using Example = std::pair<std::map<int, std::vector<float>>,
                          std::map<int, std::vector<float>>>;
30
31using namespace android::nn::wrapper;
32
33namespace add {
34std::vector<Example> examples = {
35// Generated add
36#include "generated/examples/add_tests.example.cc"
37};
38// Generated model constructor
39#include "generated/models/add.model.cpp"
40}  // add
41
42namespace conv_1_h3_w2_SAME {
43std::vector<Example> examples = {
44// Converted examples
45#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
46};
47// Generated model constructor
48#include "generated/models/conv_1_h3_w2_SAME.model.cpp"
49}  // namespace conv_1_h3_w2_SAME
50
51namespace conv_1_h3_w2_VALID {
52std::vector<Example> examples = {
53// Converted examples
54#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
55};
56// Generated model constructor
57#include "generated/models/conv_1_h3_w2_VALID.model.cpp"
58}  // namespace conv_1_h3_w2_VALID
59
60namespace conv_3_h3_w2_SAME {
61std::vector<Example> examples = {
62// Converted examples
63#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
64};
65// Generated model constructor
66#include "generated/models/conv_3_h3_w2_SAME.model.cpp"
67}  // namespace conv_3_h3_w2_SAME
68
69namespace conv_3_h3_w2_VALID {
70std::vector<Example> examples = {
71// Converted examples
72#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
73};
74// Generated model constructor
75#include "generated/models/conv_3_h3_w2_VALID.model.cpp"
76}  // namespace conv_3_h3_w2_VALID
77
78namespace depthwise_conv {
79std::vector<Example> examples = {
80// Converted examples
81#include "generated/examples/depthwise_conv_tests.example.cc"
82};
83// Generated model constructor
84#include "generated/models/depthwise_conv.model.cpp"
85}  // namespace depthwise_conv
86
87namespace avg_pool_float {
88std::vector<Example> examples = {
89// Generated avg_pool float
90#include "generated/examples/avg_pool_float_tests.example.cc"
91};
92// Generated model constructor
93#include "generated/models/avg_pool_float.model.cpp"
94}  // avg_pool_float
95
96namespace max_pool_float {
97std::vector<Example> examples = {
98// Generated max_pool float
99#include "generated/examples/max_pool_float_tests.example.cc"
100};
101// Generated model constructor
102#include "generated/models/max_pool_float.model.cpp"
103}  // max_pool_float
104
105namespace l2_pool_float {
106std::vector<Example> examples = {
107// Generated l2_pool float
108#include "generated/examples/l2_pool_float_tests.example.cc"
109};
110// Generated model constructor
111#include "generated/models/l2_pool_float.model.cpp"
112}  // l2_pool_float
113
114namespace relu_float {
115std::vector<Example> examples = {
116// Generated relu float
117#include "generated/examples/relu_float_tests.example.cc"
118};
119// Generated model constructor
120#include "generated/models/relu_float.model.cpp"
121}  // relu_float
122
123namespace relu1_float {
124std::vector<Example> examples = {
125// Generated relu1 float
126#include "generated/examples/relu1_float_tests.example.cc"
127};
128// Generated model constructor
129#include "generated/models/relu1_float.model.cpp"
130}  // relu1_float
131
132namespace relu6_float {
133std::vector<Example> examples = {
134// Generated relu6 float
135#include "generated/examples/relu6_float_tests.example.cc"
136};
137// Generated model constructor
138#include "generated/models/relu6_float.model.cpp"
139}  // relu6_float
140
141namespace mobilenet {
142std::vector<Example> examples = {
143// Converted examples
144#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
145};
146// Generated model constructor
147#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
148}  // namespace mobilenet
149
150namespace {
151bool Execute(void (*create_model)(Model*), std::vector<Example>& examples) {
152    Model model;
153    create_model(&model);
154
155    int example_no = 1;
156    bool error = false;
157
158    for (auto& example : examples) {
159        Request request(&model);
160
161        // Go through all inputs
162        for (auto& i : example.first) {
163            std::vector<float>& input = i.second;
164            request.setInput(i.first, (const void*)input.data(),
165                             input.size() * sizeof(float));
166        }
167
168        std::map<int, std::vector<float>> test_outputs;
169
170        assert(example.second.size() == 1);
171        int output_no = 0;
172        for (auto& i : example.second) {
173            std::vector<float>& output = i.second;
174            test_outputs[i.first].resize(output.size());
175            std::vector<float>& test_output = test_outputs[i.first];
176            request.setOutput(output_no++, (void*)test_output.data(),
177                              test_output.size() * sizeof(float));
178        }
179        Result r = request.compute();
180        if (r != Result::NO_ERROR)
181            std::cerr << "Request was not completed normally\n";
182        bool mismatch = false;
183        for (auto& i : example.second) {
184            std::vector<float>& test = test_outputs[i.first];
185            std::vector<float>& golden = i.second;
186            for (unsigned i = 0; i < golden.size(); i++) {
187                if (std::fabs(golden[i] - test[i]) > 1.5e-5f) {
188                    std::cerr << " output[" << i << "] = " << test[i]
189                              << " (should be " << golden[i] << ")\n";
190                    error = error || true;
191                    mismatch = mismatch || true;
192                }
193            }
194        }
195        if (mismatch) {
196            std::cerr << "Example: " << example_no++;
197            std::cerr << " failed\n";
198        }
199    }
200    return error;
201}
202
203class GeneratedTests : public ::testing::Test {
204   protected:
205    virtual void SetUp() {
206        ASSERT_EQ(android::nn::wrapper::Initialize(),
207                  android::nn::wrapper::Result::NO_ERROR);
208    }
209
210    virtual void TearDown() { android::nn::wrapper::Shutdown(); }
211};
212}  // namespace
213
214TEST_F(GeneratedTests, add) {
215    ASSERT_EQ(
216        Execute(add::CreateModel, add::examples),
217        0);
218}
219
220TEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
221    ASSERT_EQ(
222        Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples),
223        0);
224}
225
226TEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
227    ASSERT_EQ(
228        Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples),
229        0);
230}
231
232TEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
233    ASSERT_EQ(
234        Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples),
235        0);
236}
237
238TEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
239    ASSERT_EQ(
240        Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples),
241        0);
242}
243
244TEST_F(GeneratedTests, depthwise_conv) {
245    ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples),
246              0);
247}
248
249TEST_F(GeneratedTests, avg_pool_float) {
250    ASSERT_EQ(
251        Execute(avg_pool_float::CreateModel, avg_pool_float::examples),
252        0);
253}
254
255TEST_F(GeneratedTests, max_pool_float) {
256    ASSERT_EQ(
257        Execute(max_pool_float::CreateModel, max_pool_float::examples),
258        0);
259}
260
261TEST_F(GeneratedTests, l2_pool_float) {
262    ASSERT_EQ(
263        Execute(l2_pool_float::CreateModel, l2_pool_float::examples),
264        0);
265}
266
267TEST_F(GeneratedTests, relu_float) {
268    ASSERT_EQ(
269        Execute(relu_float::CreateModel, relu_float::examples),
270        0);
271}
272
273TEST_F(GeneratedTests, relu1_float) {
274    ASSERT_EQ(
275        Execute(relu1_float::CreateModel, relu1_float::examples),
276        0);
277}
278
279TEST_F(GeneratedTests, relu6_float) {
280    ASSERT_EQ(
281        Execute(relu6_float::CreateModel, relu6_float::examples),
282        0);
283}
284
285TEST_F(GeneratedTests, mobilenet) {
286    ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0);
287}
288