// TestGenerated.cpp revision 1fe4cea7a58045eb715cade0535220b5fb7f658d
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Top level driver for models and examples generated by test_generator.py
18
19#include "NeuralNetworksWrapper.h"
20#include "TestHarness.h"
21
22//#include <android-base/logging.h>
23#include <gtest/gtest.h>
24#include <cassert>
25#include <cmath>
26#include <iostream>
27#include <map>
28
29namespace generated_tests {
30using namespace android::nn::wrapper;
31
32template <typename T>
33class Example {
34   public:
35    typedef T ElementType;
36    typedef std::pair<std::map<int, std::vector<T>>,
37                      std::map<int, std::vector<T>>>
38        ExampleType;
39
40    static bool Execute(std::function<void(Model*)> create_model,
41                        std::vector<ExampleType>& examples,
42                        std::function<bool(const T, const T)> compare) {
43        Model model;
44        create_model(&model);
45        model.finish();
46
47        int example_no = 1;
48        bool error = false;
49        for (auto& example : examples) {
50            Compilation compilation(&model);
51            compilation.finish();
52            Execution execution(&compilation);
53
54            // Go through all inputs
55            for (auto& i : example.first) {
56                std::vector<T>& input = i.second;
57                execution.setInput(i.first, (const void*)input.data(),
58                                   input.size() * sizeof(T));
59            }
60
61            std::map<int, std::vector<T>> test_outputs;
62
63            assert(example.second.size() == 1);
64            int output_no = 0;
65            for (auto& i : example.second) {
66                std::vector<T>& output = i.second;
67                test_outputs[i.first].resize(output.size());
68                std::vector<T>& test_output = test_outputs[i.first];
69                execution.setOutput(output_no++, (void*)test_output.data(),
70                                    test_output.size() * sizeof(T));
71            }
72            Result r = execution.compute();
73            if (r != Result::NO_ERROR)
74                std::cerr << "Execution was not completed normally\n";
75            bool mismatch = false;
76            for (auto& i : example.second) {
77                const std::vector<T>& test = test_outputs[i.first];
78                const std::vector<T>& golden = i.second;
79                for (unsigned i = 0; i < golden.size(); i++) {
80                    if (compare(golden[i], test[i])) {
81                        std::cerr << " output[" << i << "] = " << (float)test[i]
82                                  << " (should be " << (float)golden[i]
83                                  << ")\n";
84                        error = error || true;
85                        mismatch = mismatch || true;
86                    }
87                }
88            }
89            if (mismatch) {
90                std::cerr << "Example: " << example_no++;
91                std::cerr << " failed\n";
92            }
93        }
94        return error;
95    }
96
97    // Test driver for those generated from ml/nn/runtime/test/spec
98    static void Execute(std::function<void(Model*)> create_model,
99                        std::function<bool(int)> is_ignored,
100                        std::vector<MixedTypedExampleType>& examples) {
101        Model model;
102        create_model(&model);
103        model.finish();
104
105        int example_no = 1;
106        for (auto& example : examples) {
107            SCOPED_TRACE(example_no++);
108            MixedTyped inputs = example.first;
109            const MixedTyped &golden = example.second;
110
111            Compilation compilation(&model);
112            compilation.finish();
113            Execution execution(&compilation);
114
115            // Go through all ty-typed inputs
116            for_all(inputs, [&execution](int idx, auto p, auto s) {
117                ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, p, s));
118            });
119
120            MixedTyped test;
121            // Go through all typed outputs
122            resize_accordingly<float>(golden, test);
123            resize_accordingly<int32_t>(golden, test);
124            resize_accordingly<uint8_t>(golden, test);
125            for_all(test, [&execution](int idx, void* p, auto s) {
126                ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, p, s));
127            });
128
129            Result r = execution.compute();
130            ASSERT_EQ(Result::NO_ERROR, r);
131            // Filter out don't cares
132            MixedTyped filtered_golden;
133            MixedTyped filtered_test;
134            filter<float>(golden, &filtered_golden, is_ignored);
135            filter<float>(test, &filtered_test, is_ignored);
136            filter<int32_t>(golden, &filtered_golden, is_ignored);
137            filter<int32_t>(test, &filtered_test, is_ignored);
138            filter<uint8_t>(golden, &filtered_golden, is_ignored);
139            filter<uint8_t>(test, &filtered_test, is_ignored);
140#define USE_EXPECT_FLOAT_EQ 1
141#ifdef USE_EXPECT_FLOAT_EQ
142            // We want "close-enough" results for float
143            for_each<float>(filtered_golden,
144                            [&filtered_test](int index, auto& m) {
145                auto& test_float_operands =
146                    std::get<Float32Operands>(filtered_test);
147                auto& test_float = test_float_operands[index];
148                for (unsigned int i = 0; i < m.size(); i++) {
149                    SCOPED_TRACE(i);
150                    EXPECT_NEAR(m[i], test_float[i], 1.e-5);
151                }
152            });
153#else  // Use EXPECT_EQ instead; nicer error reporting
154            EXPECT_EQ(std::get<Float32Operands>(filtered_golden),
155                      std::get<Float32Operands>(filtered_test));
156#endif
157            EXPECT_EQ(std::get<Int32Operands>(filtered_golden),
158                      std::get<Int32Operands>(filtered_test));
159            EXPECT_EQ(std::get<Quant8Operands>(filtered_golden),
160                      std::get<Quant8Operands>(filtered_test));
161        }
162    }
163};
164};  // namespace generated_tests
165
166using namespace android::nn::wrapper;
167// Float32 examples
168typedef generated_tests::Example<float>::ExampleType Example;
169// Mixed-typed examples
170typedef generated_tests::MixedTypedExampleType MixedTypedExample;
171
172void Execute(std::function<void(Model*)> create_model,
173             std::function<bool(int)> is_ignored,
174             std::vector<MixedTypedExample>& examples) {
175    generated_tests::Example<float>::Execute(create_model,
176                                             is_ignored, examples);
177}
178
179class GeneratedTests : public ::testing::Test {
180   protected:
181    virtual void SetUp() {
182        // For detailed logs, uncomment this line:
183        // SetMinimumLogSeverity(android::base::VERBOSE);
184    }
185};
186
187// Testcases generated from runtime/test/specs/*.mod.py
188#include "generated/all_generated_tests.cpp"
189// End of testcases generated from runtime/test/specs/*.mod.py
190
// Below are testcases generated from TFLite testcases.
192namespace conv_1_h3_w2_SAME {
193std::vector<Example> examples = {
194// Converted examples
195#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
196};
197// Generated model constructor
198#include "generated/models/conv_1_h3_w2_SAME.model.cpp"
199}  // namespace conv_1_h3_w2_SAME
200
201namespace conv_1_h3_w2_VALID {
202std::vector<Example> examples = {
203// Converted examples
204#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
205};
206// Generated model constructor
207#include "generated/models/conv_1_h3_w2_VALID.model.cpp"
208}  // namespace conv_1_h3_w2_VALID
209
210namespace conv_3_h3_w2_SAME {
211std::vector<Example> examples = {
212// Converted examples
213#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
214};
215// Generated model constructor
216#include "generated/models/conv_3_h3_w2_SAME.model.cpp"
217}  // namespace conv_3_h3_w2_SAME
218
219namespace conv_3_h3_w2_VALID {
220std::vector<Example> examples = {
221// Converted examples
222#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
223};
224// Generated model constructor
225#include "generated/models/conv_3_h3_w2_VALID.model.cpp"
226}  // namespace conv_3_h3_w2_VALID
227
228namespace depthwise_conv {
229std::vector<Example> examples = {
230// Converted examples
231#include "generated/examples/depthwise_conv_tests.example.cc"
232};
233// Generated model constructor
234#include "generated/models/depthwise_conv.model.cpp"
235}  // namespace depthwise_conv
236
237namespace mobilenet {
238std::vector<Example> examples = {
239// Converted examples
240#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
241};
242// Generated model constructor
243#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
244}  // namespace mobilenet
245
246namespace {
247bool Execute(std::function<void(Model*)> create_model,
248             std::vector<Example>& examples) {
249    return generated_tests::Example<float>::Execute(
250        create_model, examples, [](float golden, float test) {
251            return std::fabs(golden - test) > 1.5e-5f;
252        });
253}
254}  // namespace
255
256TEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
257    ASSERT_EQ(
258        Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples),
259        0);
260}
261
262TEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
263    ASSERT_EQ(
264        Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples),
265        0);
266}
267
268TEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
269    ASSERT_EQ(
270        Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples),
271        0);
272}
273
274TEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
275    ASSERT_EQ(
276        Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples),
277        0);
278}
279
280TEST_F(GeneratedTests, depthwise_conv) {
281    ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples),
282              0);
283}
284
285TEST_F(GeneratedTests, mobilenet) {
286    ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0);
287}
288