TestGenerated.cpp revision 544739620cd7f37d40524d2407c92042e485c73f
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Top level driver for models and examples generated by test_generator.py
18
19#include "NeuralNetworksWrapper.h"
20#include "TestHarness.h"
21
22#include <gtest/gtest.h>
23#include <cassert>
24#include <cmath>
25#include <functional>
26#include <iostream>
27#include <map>
28
29namespace generated_tests {
30using namespace android::nn::wrapper;
31
32template <typename T>
33class Example {
34   public:
35    typedef T ElementType;
36    typedef std::pair<std::map<int, std::vector<T>>,
37                      std::map<int, std::vector<T>>>
38        ExampleType;
39
40    static bool Execute(std::function<void(Model*)> create_model,
41                        std::vector<ExampleType>& examples,
42                        std::function<bool(const T, const T)> compare) {
43        Model model;
44        create_model(&model);
45        model.finish();
46
47        int example_no = 1;
48        bool error = false;
49        for (auto& example : examples) {
50            Compilation compilation(&model);
51            compilation.compile();
52            Request request(&compilation);
53
54            // Go through all inputs
55            for (auto& i : example.first) {
56                std::vector<T>& input = i.second;
57                request.setInput(i.first, (const void*)input.data(),
58                                 input.size() * sizeof(T));
59            }
60
61            std::map<int, std::vector<T>> test_outputs;
62
63            assert(example.second.size() == 1);
64            int output_no = 0;
65            for (auto& i : example.second) {
66                std::vector<T>& output = i.second;
67                test_outputs[i.first].resize(output.size());
68                std::vector<T>& test_output = test_outputs[i.first];
69                request.setOutput(output_no++, (void*)test_output.data(),
70                                  test_output.size() * sizeof(T));
71            }
72            Result r = request.compute();
73            if (r != Result::NO_ERROR)
74                std::cerr << "Request was not completed normally\n";
75            bool mismatch = false;
76            for (auto& i : example.second) {
77                const std::vector<T>& test = test_outputs[i.first];
78                const std::vector<T>& golden = i.second;
79                for (unsigned i = 0; i < golden.size(); i++) {
80                    if (compare(golden[i], test[i])) {
81                        std::cerr << " output[" << i << "] = " << (float)test[i]
82                                  << " (should be " << (float)golden[i]
83                                  << ")\n";
84                        error = error || true;
85                        mismatch = mismatch || true;
86                    }
87                }
88            }
89            if (mismatch) {
90                std::cerr << "Example: " << example_no++;
91                std::cerr << " failed\n";
92            }
93        }
94        return error;
95    }
96
97    // Test driver for those generated from ml/nn/runtime/test/spec
98    static void Execute(std::function<void(Model*)> create_model,
99                        std::vector<MixedTypedExampleType>& examples) {
100        Model model;
101        create_model(&model);
102        model.finish();
103
104        int example_no = 1;
105        for (auto& example : examples) {
106            SCOPED_TRACE(example_no++);
107
108            MixedTyped& inputs = example.first;
109            MixedTyped& golden = example.second;
110
111            Compilation compilation(&model);
112            compilation.compile();
113            Request request(&compilation);
114
115            // Go through all ty-typed inputs
116#define SET_TYPED_TENSOR_INPUT(ty)                                     \
117    for (auto& i : std::get<std::map<int, std::vector<ty>>>(inputs)) { \
118        request.setInput(i.first, (const void*)i.second.data(),        \
119                         i.second.size() * sizeof(ty));                \
120    }
121
122            SET_TYPED_TENSOR_INPUT(float);
123            SET_TYPED_TENSOR_INPUT(int32_t);
124            SET_TYPED_TENSOR_INPUT(uint8_t);
125#undef SET_TYPED_TENSOR_INPUT
126
127            MixedTyped test;
128            // Go through all typed outputs
129#define SET_TYPED_OUTPUT(ty)                                              \
130    auto& golden_##ty = std::get<std::map<int, std::vector<ty>>>(golden); \
131    auto& test_##ty = std::get<std::map<int, std::vector<ty>>>(test);     \
132    for (auto& i : golden_##ty) {                                         \
133        int idx = i.first;                                                \
134        auto& golden_output = i.second;                                   \
135        test_##ty[idx].resize(golden_output.size());                      \
136        request.setOutput(idx, (void*)test_##ty[idx].data(),              \
137                          test_##ty[idx].size() * sizeof(ty));            \
138    }
139            SET_TYPED_OUTPUT(float);
140            SET_TYPED_OUTPUT(int32_t);
141            SET_TYPED_OUTPUT(uint8_t);
142#undef SET_TYPED_OUTPUT
143
144            Result r = request.compute();
145            ASSERT_EQ(Result::NO_ERROR, r);
146
147            EXPECT_EQ(golden_float, test_float);
148            EXPECT_EQ(golden_int32_t, test_int32_t);
149            EXPECT_EQ(golden_uint8_t, test_uint8_t);
150        }
151    }
152};
153};  // namespace generated_tests
154
155using namespace android::nn::wrapper;
156// Float32 examples
157typedef generated_tests::Example<float>::ExampleType Example;
158// Mixed-typed examples
159typedef generated_tests::MixedTypedExampleType MixedTypedExample;
160
161void Execute(std::function<void(Model*)> create_model,
162             std::vector<MixedTypedExample>& examples) {
163    generated_tests::Example<float>::Execute(create_model, examples);
164}
165
166class GeneratedTests : public ::testing::Test {
167   protected:
168    virtual void SetUp() {
169        ASSERT_EQ(android::nn::wrapper::Initialize(),
170                  android::nn::wrapper::Result::NO_ERROR);
171    }
172
173    virtual void TearDown() { android::nn::wrapper::Shutdown(); }
174};
175
176// Testcases generated from runtime/test/specs/*.mod.py
177#include "generated/all_generated_tests.cpp"
178// End of testcases generated from runtime/test/specs/*.mod.py
179
// Below are test cases generated from TFLite test cases.
// TFLite-derived test case conv_1_h3_w2_SAME: float-only examples plus the
// generated model constructor, run by the TEST_F of the same name below.
// NOTE(review): the included .example.cc presumably expands to a
// comma-separated list of brace-initialized Example values -- confirm
// against test_generator.py output.
181namespace conv_1_h3_w2_SAME {
// Each element is an (inputs, golden outputs) pair of index -> data maps.
182std::vector<Example> examples = {
183// Converted examples
184#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
185};
186// Generated model constructor
// Expected to define CreateModel, which the TEST_F passes to Execute() as
// the model builder.
187#include "generated/models/conv_1_h3_w2_SAME.model.cpp"
188}  // namespace conv_1_h3_w2_SAME
189
// TFLite-derived test case conv_1_h3_w2_VALID: float-only examples plus the
// generated model constructor, run by the TEST_F of the same name below.
// NOTE(review): the included .example.cc presumably expands to a
// comma-separated list of brace-initialized Example values -- confirm.
190namespace conv_1_h3_w2_VALID {
// Each element is an (inputs, golden outputs) pair of index -> data maps.
191std::vector<Example> examples = {
192// Converted examples
193#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
194};
195// Generated model constructor
// Expected to define CreateModel, which the TEST_F passes to Execute() as
// the model builder.
196#include "generated/models/conv_1_h3_w2_VALID.model.cpp"
197}  // namespace conv_1_h3_w2_VALID
198
// TFLite-derived test case conv_3_h3_w2_SAME: float-only examples plus the
// generated model constructor, run by the TEST_F of the same name below.
// NOTE(review): the included .example.cc presumably expands to a
// comma-separated list of brace-initialized Example values -- confirm.
199namespace conv_3_h3_w2_SAME {
// Each element is an (inputs, golden outputs) pair of index -> data maps.
200std::vector<Example> examples = {
201// Converted examples
202#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
203};
204// Generated model constructor
// Expected to define CreateModel, which the TEST_F passes to Execute() as
// the model builder.
205#include "generated/models/conv_3_h3_w2_SAME.model.cpp"
206}  // namespace conv_3_h3_w2_SAME
207
// TFLite-derived test case conv_3_h3_w2_VALID: float-only examples plus the
// generated model constructor, run by the TEST_F of the same name below.
// NOTE(review): the included .example.cc presumably expands to a
// comma-separated list of brace-initialized Example values -- confirm.
208namespace conv_3_h3_w2_VALID {
// Each element is an (inputs, golden outputs) pair of index -> data maps.
209std::vector<Example> examples = {
210// Converted examples
211#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
212};
213// Generated model constructor
// Expected to define CreateModel, which the TEST_F passes to Execute() as
// the model builder.
214#include "generated/models/conv_3_h3_w2_VALID.model.cpp"
215}  // namespace conv_3_h3_w2_VALID
216
// TFLite-derived test case depthwise_conv: float-only examples plus the
// generated model constructor, run by the TEST_F of the same name below.
// NOTE(review): the included .example.cc presumably expands to a
// comma-separated list of brace-initialized Example values -- confirm.
217namespace depthwise_conv {
// Each element is an (inputs, golden outputs) pair of index -> data maps.
218std::vector<Example> examples = {
219// Converted examples
220#include "generated/examples/depthwise_conv_tests.example.cc"
221};
222// Generated model constructor
// Expected to define CreateModel, which the TEST_F passes to Execute() as
// the model builder.
223#include "generated/models/depthwise_conv.model.cpp"
224}  // namespace depthwise_conv
225
// TFLite-derived test case "mobilenet": float-only examples plus the
// generated model constructor, run by the TEST_F of the same name below.
// The namespace name is shorter than the generated file names
// (mobilenet_224_gender_basic_fixed).
// NOTE(review): the included .example.cc presumably expands to a
// comma-separated list of brace-initialized Example values -- confirm.
226namespace mobilenet {
// Each element is an (inputs, golden outputs) pair of index -> data maps.
227std::vector<Example> examples = {
228// Converted examples
229#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
230};
231// Generated model constructor
// Expected to define CreateModel, which the TEST_F passes to Execute() as
// the model builder.
232#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
233}  // namespace mobilenet
234
235namespace {
236bool Execute(std::function<void(Model*)> create_model,
237             std::vector<Example>& examples) {
238    return generated_tests::Example<float>::Execute(
239        create_model, examples, [](float golden, float test) {
240            return std::fabs(golden - test) > 1.5e-5f;
241        });
242}
243}  // namespace
244
245
246TEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
247    ASSERT_EQ(
248        Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples),
249        0);
250}
251
252TEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
253    ASSERT_EQ(
254        Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples),
255        0);
256}
257
258TEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
259    ASSERT_EQ(
260        Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples),
261        0);
262}
263
264TEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
265    ASSERT_EQ(
266        Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples),
267        0);
268}
269
270TEST_F(GeneratedTests, depthwise_conv) {
271    ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples),
272              0);
273}
274
275TEST_F(GeneratedTests, mobilenet) {
276    ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0);
277}
278