TestGenerated.cpp revision 62cc2758c1c2d303861e209f26bddcf4d7564b73
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Top level driver for models and examples generated by test_generator.py
18
19#include "NeuralNetworksWrapper.h"
20#include "TestHarness.h"
21
22#include <gtest/gtest.h>
23#include <cassert>
24#include <cmath>
25#include <iostream>
26#include <map>
27
28namespace generated_tests {
29using namespace android::nn::wrapper;
30
31template <typename T>
32class Example {
33   public:
34    typedef T ElementType;
35    typedef std::pair<std::map<int, std::vector<T>>,
36                      std::map<int, std::vector<T>>>
37        ExampleType;
38
39    static bool Execute(std::function<void(Model*)> create_model,
40                        std::vector<ExampleType>& examples,
41                        std::function<bool(const T, const T)> compare) {
42        Model model;
43        create_model(&model);
44        model.finish();
45
46        int example_no = 1;
47        bool error = false;
48        for (auto& example : examples) {
49            Compilation compilation(&model);
50            compilation.finish();
51            Execution execution(&compilation);
52
53            // Go through all inputs
54            for (auto& i : example.first) {
55                std::vector<T>& input = i.second;
56                // We interpret an empty vector as an optional argument
57                // that has been omitted.
58                if (input.size() == 0) {
59                    execution.setInput(i.first, nullptr, 0);
60                } else {
61                    execution.setInput(i.first, (const void*)input.data(),
62                                       input.size() * sizeof(T));
63                }
64            }
65
66            std::map<int, std::vector<T>> test_outputs;
67
68            assert(example.second.size() == 1);
69            int output_no = 0;
70            for (auto& i : example.second) {
71                std::vector<T>& output = i.second;
72                test_outputs[i.first].resize(output.size());
73                std::vector<T>& test_output = test_outputs[i.first];
74                execution.setOutput(output_no++, (void*)test_output.data(),
75                                    test_output.size() * sizeof(T));
76            }
77            Result r = execution.compute();
78            if (r != Result::NO_ERROR)
79                std::cerr << "Execution was not completed normally\n";
80            bool mismatch = false;
81            for (auto& i : example.second) {
82                const std::vector<T>& test = test_outputs[i.first];
83                const std::vector<T>& golden = i.second;
84                for (unsigned i = 0; i < golden.size(); i++) {
85                    if (compare(golden[i], test[i])) {
86                        std::cerr << " output[" << i << "] = " << (float)test[i]
87                                  << " (should be " << (float)golden[i]
88                                  << ")\n";
89                        error = error || true;
90                        mismatch = mismatch || true;
91                    }
92                }
93            }
94            if (mismatch) {
95                std::cerr << "Example: " << example_no++;
96                std::cerr << " failed\n";
97            }
98        }
99        return error;
100    }
101
102    // Test driver for those generated from ml/nn/runtime/test/spec
103    static void Execute(std::function<void(Model*)> create_model,
104                        std::function<bool(int)> is_ignored,
105                        std::vector<MixedTypedExampleType>& examples) {
106        Model model;
107        create_model(&model);
108        model.finish();
109
110        int example_no = 1;
111        for (auto& example : examples) {
112            SCOPED_TRACE(example_no++);
113            MixedTyped inputs = example.first;
114            const MixedTyped& golden = example.second;
115
116            Compilation compilation(&model);
117            compilation.finish();
118            Execution execution(&compilation);
119
120            // Set all inputs
121            for_all(inputs, [&execution](int idx, const void* p, size_t s) {
122                ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, p, s));
123            });
124
125            MixedTyped test;
126            // Go through all typed outputs
127            resize_accordingly(golden, test);
128            for_all(test, [&execution](int idx, void* p, size_t s) {
129                ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, p, s));
130            });
131
132            Result r = execution.compute();
133            ASSERT_EQ(Result::NO_ERROR, r);
134            // Filter out don't cares
135            MixedTyped filtered_golden;
136            MixedTyped filtered_test;
137            filter(golden, &filtered_golden, is_ignored);
138            filter(test, &filtered_test, is_ignored);
139            // We want "close-enough" results for float
140            compare(filtered_golden, filtered_test);
141        }
142    }
143};
144};  // namespace generated_tests
145
146using namespace android::nn::wrapper;
147// Float32 examples
148typedef generated_tests::Example<float>::ExampleType Example;
149// Mixed-typed examples
150typedef generated_tests::MixedTypedExampleType MixedTypedExample;
151
152void Execute(std::function<void(Model*)> create_model,
153             std::function<bool(int)> is_ignored,
154             std::vector<MixedTypedExample>& examples) {
155    generated_tests::Example<float>::Execute(create_model, is_ignored,
156                                             examples);
157}
158
159class GeneratedTests : public ::testing::Test {
160   protected:
161    virtual void SetUp() {
162        // For detailed logs, uncomment this line:
163        // SetMinimumLogSeverity(android::base::VERBOSE);
164    }
165};
166
167// Testcases generated from runtime/test/specs/*.mod.py
168#include "generated/all_generated_tests.cpp"
169// End of testcases generated from runtime/test/specs/*.mod.py
170
// Below are testcases generated from TFLite testcases.
172namespace conv_1_h3_w2_SAME {
173std::vector<Example> examples = {
174// Converted examples
175#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
176};
177// Generated model constructor
178#include "generated/models/conv_1_h3_w2_SAME.model.cpp"
179}  // namespace conv_1_h3_w2_SAME
180
181namespace conv_1_h3_w2_VALID {
182std::vector<Example> examples = {
183// Converted examples
184#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
185};
186// Generated model constructor
187#include "generated/models/conv_1_h3_w2_VALID.model.cpp"
188}  // namespace conv_1_h3_w2_VALID
189
190namespace conv_3_h3_w2_SAME {
191std::vector<Example> examples = {
192// Converted examples
193#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
194};
195// Generated model constructor
196#include "generated/models/conv_3_h3_w2_SAME.model.cpp"
197}  // namespace conv_3_h3_w2_SAME
198
199namespace conv_3_h3_w2_VALID {
200std::vector<Example> examples = {
201// Converted examples
202#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
203};
204// Generated model constructor
205#include "generated/models/conv_3_h3_w2_VALID.model.cpp"
206}  // namespace conv_3_h3_w2_VALID
207
208namespace depthwise_conv {
209std::vector<Example> examples = {
210// Converted examples
211#include "generated/examples/depthwise_conv_tests.example.cc"
212};
213// Generated model constructor
214#include "generated/models/depthwise_conv.model.cpp"
215}  // namespace depthwise_conv
216
217namespace mobilenet {
218std::vector<Example> examples = {
219// Converted examples
220#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
221};
222// Generated model constructor
223#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
224}  // namespace mobilenet
225
226namespace {
227bool Execute(std::function<void(Model*)> create_model,
228             std::vector<Example>& examples) {
229    return generated_tests::Example<float>::Execute(
230        create_model, examples, [](float golden, float test) {
231            return std::fabs(golden - test) > 1.5e-5f;
232        });
233}
234}  // namespace
235
236TEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
237    ASSERT_EQ(
238        Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples),
239        0);
240}
241
242TEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
243    ASSERT_EQ(
244        Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples),
245        0);
246}
247
248TEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
249    ASSERT_EQ(
250        Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples),
251        0);
252}
253
254TEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
255    ASSERT_EQ(
256        Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples),
257        0);
258}
259
260TEST_F(GeneratedTests, depthwise_conv) {
261    ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples),
262              0);
263}
264
265TEST_F(GeneratedTests, mobilenet) {
266    ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0);
267}
268