TestGenerated.cpp revision 6a0d306cf902e13ab147c7533b2cb02540ee66d5
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Top level driver for models and examples generated by test_generator.py
18
19#include "NeuralNetworksWrapper.h"
20#include "TestHarness.h"
21
22#include <gtest/gtest.h>
23#include <cassert>
24#include <cmath>
25#include <functional>
26#include <iostream>
27#include <map>
28
29namespace generated_tests {
30using namespace android::nn::wrapper;
31
32template <typename T>
33class Example {
34   public:
35    typedef T ElementType;
36    typedef std::pair<std::map<int, std::vector<T>>,
37                      std::map<int, std::vector<T>>>
38        ExampleType;
39
40    static bool Execute(std::function<void(Model*)> create_model,
41                        std::vector<ExampleType>& examples,
42                        std::function<bool(const T, const T)> compare) {
43        Model model;
44        create_model(&model);
45
46        int example_no = 1;
47        bool error = false;
48        for (auto& example : examples) {
49            Request request(&model);
50
51            // Go through all inputs
52            for (auto& i : example.first) {
53                std::vector<T>& input = i.second;
54                request.setInput(i.first, (const void*)input.data(),
55                                 input.size() * sizeof(T));
56            }
57
58            std::map<int, std::vector<T>> test_outputs;
59
60            assert(example.second.size() == 1);
61            int output_no = 0;
62            for (auto& i : example.second) {
63                std::vector<T>& output = i.second;
64                test_outputs[i.first].resize(output.size());
65                std::vector<T>& test_output = test_outputs[i.first];
66                request.setOutput(output_no++, (void*)test_output.data(),
67                                  test_output.size() * sizeof(T));
68            }
69            Result r = request.compute();
70            if (r != Result::NO_ERROR)
71                std::cerr << "Request was not completed normally\n";
72            bool mismatch = false;
73            for (auto& i : example.second) {
74                const std::vector<T>& test = test_outputs[i.first];
75                const std::vector<T>& golden = i.second;
76                for (unsigned i = 0; i < golden.size(); i++) {
77                    if (compare(golden[i], test[i])) {
78                        std::cerr << " output[" << i << "] = " << (float)test[i]
79                                  << " (should be " << (float)golden[i]
80                                  << ")\n";
81                        error = error || true;
82                        mismatch = mismatch || true;
83                    }
84                }
85            }
86            if (mismatch) {
87                std::cerr << "Example: " << example_no++;
88                std::cerr << " failed\n";
89            }
90        }
91        return error;
92    }
93
94    // Test driver for those generated from ml/nn/runtime/test/spec
95    static void Execute(std::function<void(Model*)> create_model,
96                        std::vector<MixedTypedExampleType>& examples) {
97        Model model;
98        create_model(&model);
99
100        int example_no = 1;
101        for (auto& example : examples) {
102            SCOPED_TRACE(example_no++);
103
104            MixedTyped& inputs = example.first;
105            MixedTyped& golden = example.second;
106
107            Request request(&model);
108
109            // Go through all ty-typed inputs
110#define SET_TYPED_TENSOR_INPUT(ty)                                     \
111    for (auto& i : std::get<std::map<int, std::vector<ty>>>(inputs)) { \
112        request.setInput(i.first, (const void*)i.second.data(),        \
113                         i.second.size() * sizeof(ty));                \
114    }
115
116            SET_TYPED_TENSOR_INPUT(float);
117            SET_TYPED_TENSOR_INPUT(int32_t);
118            SET_TYPED_TENSOR_INPUT(uint8_t);
119#undef SET_TYPED_TENSOR_INPUT
120
121            MixedTyped test;
122            // Go through all typed outputs
123#define SET_TYPED_OUTPUT(ty)                                              \
124    auto& golden_##ty = std::get<std::map<int, std::vector<ty>>>(golden); \
125    auto& test_##ty = std::get<std::map<int, std::vector<ty>>>(test);     \
126    for (auto& i : golden_##ty) {                                         \
127        int idx = i.first;                                                \
128        auto& golden_output = i.second;                                   \
129        test_##ty[idx].resize(golden_output.size());                      \
130        request.setOutput(idx, (void*)test_##ty[idx].data(),              \
131                          test_##ty[idx].size() * sizeof(ty));            \
132    }
133            SET_TYPED_OUTPUT(float);
134            SET_TYPED_OUTPUT(int32_t);
135            SET_TYPED_OUTPUT(uint8_t);
136#undef SET_TYPED_OUTPUT
137
138            Result r = request.compute();
139            ASSERT_EQ(Result::NO_ERROR, r);
140
141            EXPECT_EQ(golden_float, test_float);
142            EXPECT_EQ(golden_int32_t, test_int32_t);
143            EXPECT_EQ(golden_uint8_t, test_uint8_t);
144        }
145    }
146};
147};  // namespace generated_tests
148
149using namespace android::nn::wrapper;
150// Float32 examples
151typedef generated_tests::Example<float>::ExampleType Example;
152// Mixed-typed examples
153typedef generated_tests::MixedTypedExampleType MixedTypedExample;
154
155void Execute(std::function<void(Model*)> create_model,
156             std::vector<MixedTypedExample>& examples) {
157    generated_tests::Example<float>::Execute(create_model, examples);
158}
159
160class GeneratedTests : public ::testing::Test {
161   protected:
162    virtual void SetUp() {
163        ASSERT_EQ(android::nn::wrapper::Initialize(),
164                  android::nn::wrapper::Result::NO_ERROR);
165    }
166
167    virtual void TearDown() { android::nn::wrapper::Shutdown(); }
168};
169
170// Testcases generated from runtime/test/specs/*.mod.py
171#include "generated/all_generated_tests.cpp"
172// End of testcases generated from runtime/test/specs/*.mod.py
173
// Below are testcases generated from TFLite testcases.
175namespace conv_1_h3_w2_SAME {
176std::vector<Example> examples = {
177// Converted examples
178#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
179};
180// Generated model constructor
181#include "generated/models/conv_1_h3_w2_SAME.model.cpp"
182}  // namespace conv_1_h3_w2_SAME
183
184namespace conv_1_h3_w2_VALID {
185std::vector<Example> examples = {
186// Converted examples
187#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
188};
189// Generated model constructor
190#include "generated/models/conv_1_h3_w2_VALID.model.cpp"
191}  // namespace conv_1_h3_w2_VALID
192
193namespace conv_3_h3_w2_SAME {
194std::vector<Example> examples = {
195// Converted examples
196#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
197};
198// Generated model constructor
199#include "generated/models/conv_3_h3_w2_SAME.model.cpp"
200}  // namespace conv_3_h3_w2_SAME
201
202namespace conv_3_h3_w2_VALID {
203std::vector<Example> examples = {
204// Converted examples
205#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
206};
207// Generated model constructor
208#include "generated/models/conv_3_h3_w2_VALID.model.cpp"
209}  // namespace conv_3_h3_w2_VALID
210
211namespace depthwise_conv {
212std::vector<Example> examples = {
213// Converted examples
214#include "generated/examples/depthwise_conv_tests.example.cc"
215};
216// Generated model constructor
217#include "generated/models/depthwise_conv.model.cpp"
218}  // namespace depthwise_conv
219
220namespace mobilenet {
221std::vector<Example> examples = {
222// Converted examples
223#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
224};
225// Generated model constructor
226#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
227}  // namespace mobilenet
228
229namespace {
230bool Execute(std::function<void(Model*)> create_model,
231             std::vector<Example>& examples) {
232    return generated_tests::Example<float>::Execute(
233        create_model, examples, [](float golden, float test) {
234            return std::fabs(golden - test) > 1.5e-5f;
235        });
236}
237}  // namespace
238
239
240TEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
241    ASSERT_EQ(
242        Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples),
243        0);
244}
245
246TEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
247    ASSERT_EQ(
248        Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples),
249        0);
250}
251
252TEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
253    ASSERT_EQ(
254        Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples),
255        0);
256}
257
258TEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
259    ASSERT_EQ(
260        Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples),
261        0);
262}
263
264TEST_F(GeneratedTests, depthwise_conv) {
265    ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples),
266              0);
267}
268
269TEST_F(GeneratedTests, mobilenet) {
270    ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0);
271}
272