TestGenerated.cpp revision dcd2fbf1da7ebdc1aa1b57c74db4fffd2911e3b8
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Top level driver for models and examples converted from TFLite tests
18
19#include "NeuralNetworksWrapper.h"
20
21#include <gtest/gtest.h>
22#include <cassert>
23#include <cmath>
24#include <iostream>
25#include <map>
26
// One test case. `first` holds the input buffers and `second` the expected
// ("golden") output buffers; both are float vectors keyed by an int.
using Example = std::pair<std::map<int, std::vector<float>>,
                          std::map<int, std::vector<float>>>;
30
31using namespace android::nn::wrapper;
32
// Test data and model for the "add" test case.
namespace add {
// List of examples; the initializer's contents are supplied by textually
// including the generated examples file between the braces.
std::vector<Example> examples = {
// Generated examples
#include "generated/examples/add_tests.example.cc"
};
// Generated model constructor. The add test at the bottom of this file passes
// add::CreateModel to Execute(), so the included file is expected to define it.
#include "generated/models/add.model.cpp"
}  // namespace add
41
// Test data and model for the "conv_1_h3_w2_SAME" test case.
namespace conv_1_h3_w2_SAME {
// List of examples; the initializer's contents are supplied by textually
// including the converted examples file between the braces.
std::vector<Example> examples = {
// Converted examples
#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
};
// Generated model constructor. The matching test at the bottom of this file
// passes conv_1_h3_w2_SAME::CreateModel to Execute(), so the included file is
// expected to define it.
#include "generated/models/conv_1_h3_w2_SAME.model.cpp"
}  // namespace conv_1_h3_w2_SAME
50
// Test data and model for the "conv_1_h3_w2_VALID" test case.
namespace conv_1_h3_w2_VALID {
// List of examples; the initializer's contents are supplied by textually
// including the converted examples file between the braces.
std::vector<Example> examples = {
// Converted examples
#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
};
// Generated model constructor. The matching test at the bottom of this file
// passes conv_1_h3_w2_VALID::CreateModel to Execute(), so the included file is
// expected to define it.
#include "generated/models/conv_1_h3_w2_VALID.model.cpp"
}  // namespace conv_1_h3_w2_VALID
59
// Test data and model for the "conv_3_h3_w2_SAME" test case.
namespace conv_3_h3_w2_SAME {
// List of examples; the initializer's contents are supplied by textually
// including the converted examples file between the braces.
std::vector<Example> examples = {
// Converted examples
#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
};
// Generated model constructor. The matching test at the bottom of this file
// passes conv_3_h3_w2_SAME::CreateModel to Execute(), so the included file is
// expected to define it.
#include "generated/models/conv_3_h3_w2_SAME.model.cpp"
}  // namespace conv_3_h3_w2_SAME
68
// Test data and model for the "conv_3_h3_w2_VALID" test case.
namespace conv_3_h3_w2_VALID {
// List of examples; the initializer's contents are supplied by textually
// including the converted examples file between the braces.
std::vector<Example> examples = {
// Converted examples
#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
};
// Generated model constructor. The matching test at the bottom of this file
// passes conv_3_h3_w2_VALID::CreateModel to Execute(), so the included file is
// expected to define it.
#include "generated/models/conv_3_h3_w2_VALID.model.cpp"
}  // namespace conv_3_h3_w2_VALID
77
// Test data and model for the "depthwise_conv" test case.
namespace depthwise_conv {
// List of examples; the initializer's contents are supplied by textually
// including the converted examples file between the braces.
std::vector<Example> examples = {
// Converted examples
#include "generated/examples/depthwise_conv_tests.example.cc"
};
// Generated model constructor. The matching test at the bottom of this file
// passes depthwise_conv::CreateModel to Execute(), so the included file is
// expected to define it.
#include "generated/models/depthwise_conv.model.cpp"
}  // namespace depthwise_conv
86
// Test data and model for the "mobilenet" test case. Note that the generated
// files are named mobilenet_224_gender_basic_fixed, not just "mobilenet".
namespace mobilenet {
// List of examples; the initializer's contents are supplied by textually
// including the converted examples file between the braces.
std::vector<Example> examples = {
// Converted examples
#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
};
// Generated model constructor. The mobilenet test at the bottom of this file
// passes mobilenet::CreateModel to Execute(), so the included file is expected
// to define it.
#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
}  // namespace mobilenet
95
96namespace {
97bool Execute(void (*create_model)(Model*), std::vector<Example>& examples) {
98    Model model;
99    create_model(&model);
100
101    int example_no = 1;
102    bool error = false;
103
104    for (auto& example : examples) {
105        Request request(&model);
106
107        // Go through all inputs
108        for (auto& i : example.first) {
109            std::vector<float>& input = i.second;
110            request.setInput(i.first, (const void*)input.data(),
111                             input.size() * sizeof(float));
112        }
113
114        std::map<int, std::vector<float>> test_outputs;
115
116        assert(example.second.size() == 1);
117        int output_no = 0;
118        for (auto& i : example.second) {
119            std::vector<float>& output = i.second;
120            test_outputs[i.first].resize(output.size());
121            std::vector<float>& test_output = test_outputs[i.first];
122            request.setOutput(output_no++, (void*)test_output.data(),
123                              test_output.size() * sizeof(float));
124        }
125        Result r = request.compute();
126        if (r != Result::NO_ERROR)
127            std::cerr << "Request was not completed normally\n";
128        bool mismatch = false;
129        for (auto& i : example.second) {
130            std::vector<float>& test = test_outputs[i.first];
131            std::vector<float>& golden = i.second;
132            for (unsigned i = 0; i < golden.size(); i++) {
133                if (std::fabs(golden[i] - test[i]) > 1.5e-5f) {
134                    std::cerr << " output[" << i << "] = " << test[i]
135                              << " (should be " << golden[i] << ")\n";
136                    error = error || true;
137                    mismatch = mismatch || true;
138                }
139            }
140        }
141        if (mismatch) {
142            std::cerr << "Example: " << example_no++;
143            std::cerr << " failed\n";
144        }
145    }
146    return error;
147}
148
149class GeneratedTests : public ::testing::Test {
150   protected:
151    virtual void SetUp() {
152        ASSERT_EQ(android::nn::wrapper::Initialize(),
153                  android::nn::wrapper::Result::NO_ERROR);
154    }
155
156    virtual void TearDown() { android::nn::wrapper::Shutdown(); }
157};
158}  // namespace
159
160TEST_F(GeneratedTests, add) {
161    ASSERT_EQ(
162        Execute(add::CreateModel, add::examples),
163        0);
164}
165
166TEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
167    ASSERT_EQ(
168        Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples),
169        0);
170}
171
172TEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
173    ASSERT_EQ(
174        Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples),
175        0);
176}
177
178TEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
179    ASSERT_EQ(
180        Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples),
181        0);
182}
183
184TEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
185    ASSERT_EQ(
186        Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples),
187        0);
188}
189
190TEST_F(GeneratedTests, depthwise_conv) {
191    ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples),
192              0);
193}
194
195TEST_F(GeneratedTests, mobilenet) {
196    ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0);
197}
198