// TestGenerated.cpp revision 820215d28bed6c90f696cde0f282445d16da432e
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Top level driver for models and examples generated by test_generator.py
18
19#include "NeuralNetworksWrapper.h"
20#include "TestHarness.h"
21
22#include <gtest/gtest.h>
23#include <cassert>
24#include <cmath>
25#include <iostream>
26#include <map>
27
28namespace generated_tests {
29using namespace android::nn::wrapper;
30
31template <typename T>
32class Example {
33   public:
34    typedef T ElementType;
35    typedef std::pair<std::map<int, std::vector<T>>,
36                      std::map<int, std::vector<T>>>
37        ExampleType;
38
39    static bool Execute(std::function<void(Model*)> create_model,
40                        std::vector<ExampleType>& examples,
41                        std::function<bool(const T, const T)> compare) {
42        Model model;
43        create_model(&model);
44        model.finish();
45
46        int example_no = 1;
47        bool error = false;
48        for (auto& example : examples) {
49            Compilation compilation(&model);
50            compilation.finish();
51            Execution execution(&compilation);
52
53            // Go through all inputs
54            for (auto& i : example.first) {
55                std::vector<T>& input = i.second;
56                // We interpret an empty vector as an optional argument
57                // that has been omitted.
58                if (input.size() == 0) {
59                    execution.setInput(i.first, nullptr, 0);
60                } else {
61                    execution.setInput(i.first, (const void*)input.data(),
62                                       input.size() * sizeof(T));
63                }
64            }
65
66            std::map<int, std::vector<T>> test_outputs;
67
68            assert(example.second.size() == 1);
69            int output_no = 0;
70            for (auto& i : example.second) {
71                std::vector<T>& output = i.second;
72                test_outputs[i.first].resize(output.size());
73                std::vector<T>& test_output = test_outputs[i.first];
74                execution.setOutput(output_no++, (void*)test_output.data(),
75                                    test_output.size() * sizeof(T));
76            }
77            Result r = execution.compute();
78            if (r != Result::NO_ERROR)
79                std::cerr << "Execution was not completed normally\n";
80            bool mismatch = false;
81            for (auto& i : example.second) {
82                const std::vector<T>& test = test_outputs[i.first];
83                const std::vector<T>& golden = i.second;
84                for (unsigned i = 0; i < golden.size(); i++) {
85                    if (compare(golden[i], test[i])) {
86                        std::cerr << " output[" << i << "] = " << (float)test[i]
87                                  << " (should be " << (float)golden[i]
88                                  << ")\n";
89                        error = error || true;
90                        mismatch = mismatch || true;
91                    }
92                }
93            }
94            if (mismatch) {
95                std::cerr << "Example: " << example_no++;
96                std::cerr << " failed\n";
97            }
98        }
99        return error;
100    }
101
102    // Test driver for those generated from ml/nn/runtime/test/spec
103    static void Execute(std::function<void(Model*)> create_model,
104                        std::function<bool(int)> is_ignored,
105                        std::vector<MixedTypedExampleType>& examples) {
106        Model model;
107        create_model(&model);
108        model.finish();
109
110        int example_no = 1;
111        for (auto& example : examples) {
112            SCOPED_TRACE(example_no++);
113            MixedTyped inputs = example.first;
114            const MixedTyped& golden = example.second;
115
116            Compilation compilation(&model);
117            compilation.finish();
118            Execution execution(&compilation);
119
120            // Set all inputs
121            for_all(inputs, [&execution](int idx, const void* p, size_t s) {
122                ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, p, s));
123            });
124
125            MixedTyped test;
126            // Go through all typed outputs
127            resize_accordingly(golden, test);
128            for_all(test, [&execution](int idx, void* p, size_t s) {
129                ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, p, s));
130            });
131
132            Result r = execution.compute();
133            ASSERT_EQ(Result::NO_ERROR, r);
134            // Filter out don't cares
135            MixedTyped filtered_golden;
136            MixedTyped filtered_test;
137            filter(golden, &filtered_golden, is_ignored);
138            filter(test, &filtered_test, is_ignored);
139            // We want "close-enough" results for float
140            compare(filtered_golden, filtered_test);
141        }
142    }
143};
144};  // namespace generated_tests
145
146using namespace android::nn::wrapper;
147// Float32 examples
148typedef generated_tests::Example<float>::ExampleType Example;
149// Mixed-typed examples
150typedef generated_tests::MixedTypedExampleType MixedTypedExample;
151
152void Execute(std::function<void(Model*)> create_model,
153             std::function<bool(int)> is_ignored,
154             std::vector<MixedTypedExample>& examples) {
155    generated_tests::Example<float>::Execute(create_model, is_ignored,
156                                             examples);
157}
158
159class GeneratedTests : public ::testing::Test {
160   protected:
161    virtual void SetUp() {}
162};
163
164// Testcases generated from runtime/test/specs/*.mod.py
165#include "generated/all_generated_tests.cpp"
166// End of testcases generated from runtime/test/specs/*.mod.py
167
// Below are testcases generated from TFLite testcases.
169namespace conv_1_h3_w2_SAME {
170std::vector<Example> examples = {
171// Converted examples
172#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
173};
174// Generated model constructor
175#include "generated/models/conv_1_h3_w2_SAME.model.cpp"
176}  // namespace conv_1_h3_w2_SAME
177
178namespace conv_1_h3_w2_VALID {
179std::vector<Example> examples = {
180// Converted examples
181#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
182};
183// Generated model constructor
184#include "generated/models/conv_1_h3_w2_VALID.model.cpp"
185}  // namespace conv_1_h3_w2_VALID
186
187namespace conv_3_h3_w2_SAME {
188std::vector<Example> examples = {
189// Converted examples
190#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
191};
192// Generated model constructor
193#include "generated/models/conv_3_h3_w2_SAME.model.cpp"
194}  // namespace conv_3_h3_w2_SAME
195
196namespace conv_3_h3_w2_VALID {
197std::vector<Example> examples = {
198// Converted examples
199#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
200};
201// Generated model constructor
202#include "generated/models/conv_3_h3_w2_VALID.model.cpp"
203}  // namespace conv_3_h3_w2_VALID
204
205namespace depthwise_conv {
206std::vector<Example> examples = {
207// Converted examples
208#include "generated/examples/depthwise_conv_tests.example.cc"
209};
210// Generated model constructor
211#include "generated/models/depthwise_conv.model.cpp"
212}  // namespace depthwise_conv
213
214namespace mobilenet {
215std::vector<Example> examples = {
216// Converted examples
217#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
218};
219// Generated model constructor
220#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
221}  // namespace mobilenet
222
223namespace {
224bool Execute(std::function<void(Model*)> create_model,
225             std::vector<Example>& examples) {
226    return generated_tests::Example<float>::Execute(
227        create_model, examples, [](float golden, float test) {
228            return std::fabs(golden - test) > 1.5e-5f;
229        });
230}
231}  // namespace
232
233TEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
234    ASSERT_EQ(
235        Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples),
236        0);
237}
238
239TEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
240    ASSERT_EQ(
241        Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples),
242        0);
243}
244
245TEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
246    ASSERT_EQ(
247        Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples),
248        0);
249}
250
251TEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
252    ASSERT_EQ(
253        Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples),
254        0);
255}
256
257TEST_F(GeneratedTests, depthwise_conv) {
258    ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples),
259              0);
260}
261
262TEST_F(GeneratedTests, mobilenet) {
263    ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0);
264}
265