TestGenerated.cpp revision 242c6dc1f314646f1a87c66140f26d7623cc399a
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Top level driver for models and examples converted from TFLite tests
18
19#include "NeuralNetworksWrapper.h"
20
21#include <gtest/gtest.h>
22#include <cassert>
23#include <cmath>
24#include <iostream>
25#include <map>
26
// One converted test case.
//   .first  - input buffers, keyed by the index passed to Request::setInput.
//   .second - expected (golden) output buffers; presumably keyed by output
//             operand index (the driver currently asserts a single output).
using Example = std::pair<std::map<int, std::vector<float>>,
                          std::map<int, std::vector<float>>>;
30
31using namespace android::nn::wrapper;
32
33namespace conv_1_h3_w2_SAME {
34std::vector<Example> examples = {
35// Converted examples
36#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
37};
38// Generated model constructor
39#include "generated/models/conv_1_h3_w2_SAME.model.cpp"
40}  // namespace conv_1_h3_w2_SAME
41
42namespace conv_1_h3_w2_VALID {
43std::vector<Example> examples = {
44// Converted examples
45#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
46};
47// Generated model constructor
48#include "generated/models/conv_1_h3_w2_VALID.model.cpp"
49}  // namespace conv_1_h3_w2_VALID
50
51namespace conv_3_h3_w2_SAME {
52std::vector<Example> examples = {
53// Converted examples
54#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
55};
56// Generated model constructor
57#include "generated/models/conv_3_h3_w2_SAME.model.cpp"
58}  // namespace conv_3_h3_w2_SAME
59
60namespace conv_3_h3_w2_VALID {
61std::vector<Example> examples = {
62// Converted examples
63#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
64};
65// Generated model constructor
66#include "generated/models/conv_3_h3_w2_VALID.model.cpp"
67}  // namespace conv_3_h3_w2_VALID
68
69namespace depthwise_conv {
70std::vector<Example> examples = {
71// Converted examples
72#include "generated/examples/depthwise_conv_tests.example.cc"
73};
74// Generated model constructor
75#include "generated/models/depthwise_conv.model.cpp"
76}  // namespace depthwise_conv
77
78namespace mobilenet {
79std::vector<Example> examples = {
80// Converted examples
81#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
82};
83// Generated model constructor
84#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
85}  // namespace mobilenet
86
87namespace {
88bool Execute(void (*create_model)(Model*), std::vector<Example>& examples) {
89    Model model;
90    create_model(&model);
91
92    int example_no = 1;
93    bool error = false;
94
95    for (auto& example : examples) {
96        Request request(&model);
97
98        // Go through all inputs
99        for (auto& i : example.first) {
100            std::vector<float>& input = i.second;
101            request.setInput(i.first, (const void*)input.data(),
102                             input.size() * sizeof(float));
103        }
104
105        std::map<int, std::vector<float>> test_outputs;
106
107        assert(example.second.size() == 1);
108        int output_no = 0;
109        for (auto& i : example.second) {
110            std::vector<float>& output = i.second;
111            test_outputs[i.first].resize(output.size());
112            std::vector<float>& test_output = test_outputs[i.first];
113            request.setOutput(output_no++, (void*)test_output.data(),
114                              test_output.size() * sizeof(float));
115        }
116        Result r = request.compute();
117        if (r != Result::NO_ERROR)
118            std::cerr << "Request was not completed normally\n";
119        bool mismatch = false;
120        for (auto& i : example.second) {
121            std::vector<float>& test = test_outputs[i.first];
122            std::vector<float>& golden = i.second;
123            for (unsigned i = 0; i < golden.size(); i++) {
124                if (std::fabs(golden[i] - test[i]) > 1.5e-5f) {
125                    std::cerr << " output[" << i << "] = " << test[i]
126                              << " (should be " << golden[i] << ")\n";
127                    error = error || true;
128                    mismatch = mismatch || true;
129                }
130            }
131        }
132        if (mismatch) {
133            std::cerr << "Example: " << example_no++;
134            std::cerr << " failed\n";
135        }
136    }
137    return error;
138}
139
140class GeneratedTests : public ::testing::Test {
141   protected:
142    virtual void SetUp() {
143        ASSERT_EQ(android::nn::wrapper::Initialize(),
144                  android::nn::wrapper::Result::NO_ERROR);
145    }
146
147    virtual void TearDown() { android::nn::wrapper::Shutdown(); }
148};
149}  // namespace
150
151TEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
152    ASSERT_EQ(
153        Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples),
154        0);
155}
156
157TEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
158    ASSERT_EQ(
159        Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples),
160        0);
161}
162
163TEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
164    ASSERT_EQ(
165        Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples),
166        0);
167}
168
169TEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
170    ASSERT_EQ(
171        Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples),
172        0);
173}
174
175TEST_F(GeneratedTests, depthwise_conv) {
176    ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples),
177              0);
178}
179
180TEST_F(GeneratedTests, mobilenet) {
181    ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0);
182}
183