// TestGenerated.cpp revision 242c6dc1f314646f1a87c66140f26d7623cc399a
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Top-level test driver for models and examples converted from TFLite tests.

#include "NeuralNetworksWrapper.h"

#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

// One test case for a generated model.
//   first:  operand index -> input values fed to the request.
//   second: operand index -> expected (golden) output values.
using Example = std::pair<std::map<int, std::vector<float>>,
                          std::map<int, std::vector<float>>>;

31242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungusing namespace android::nn::wrapper;
32242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
33242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungnamespace conv_1_h3_w2_SAME {
34242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungstd::vector<Example> examples = {
35242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Converted examples
36242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/examples/conv_1_h3_w2_SAME_tests.example.cc"
37242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung};
38242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Generated model constructor
39242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/models/conv_1_h3_w2_SAME.model.cpp"
40242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}  // namespace conv_1_h3_w2_SAME
41242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
42242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungnamespace conv_1_h3_w2_VALID {
43242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungstd::vector<Example> examples = {
44242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Converted examples
45242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/examples/conv_1_h3_w2_VALID_tests.example.cc"
46242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung};
47242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Generated model constructor
48242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/models/conv_1_h3_w2_VALID.model.cpp"
49242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}  // namespace conv_1_h3_w2_VALID
50242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
51242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungnamespace conv_3_h3_w2_SAME {
52242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungstd::vector<Example> examples = {
53242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Converted examples
54242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/examples/conv_3_h3_w2_SAME_tests.example.cc"
55242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung};
56242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Generated model constructor
57242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/models/conv_3_h3_w2_SAME.model.cpp"
58242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}  // namespace conv_3_h3_w2_SAME
59242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
60242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungnamespace conv_3_h3_w2_VALID {
61242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungstd::vector<Example> examples = {
62242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Converted examples
63242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/examples/conv_3_h3_w2_VALID_tests.example.cc"
64242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung};
65242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Generated model constructor
66242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/models/conv_3_h3_w2_VALID.model.cpp"
67242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}  // namespace conv_3_h3_w2_VALID
68242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
69242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungnamespace depthwise_conv {
70242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungstd::vector<Example> examples = {
71242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Converted examples
72242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/examples/depthwise_conv_tests.example.cc"
73242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung};
74242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Generated model constructor
75242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/models/depthwise_conv.model.cpp"
76242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}  // namespace depthwise_conv
77242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
78242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungnamespace mobilenet {
79242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungstd::vector<Example> examples = {
80242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Converted examples
81242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/examples/mobilenet_224_gender_basic_fixed_tests.example.cc"
82242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung};
83242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung// Generated model constructor
84242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung#include "generated/models/mobilenet_224_gender_basic_fixed.model.cpp"
85242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}  // namespace mobilenet
86242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
87242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungnamespace {
88242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungbool Execute(void (*create_model)(Model*), std::vector<Example>& examples) {
89242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    Model model;
90242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    create_model(&model);
91242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
92242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    int example_no = 1;
93242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    bool error = false;
94242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
95242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    for (auto& example : examples) {
96242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        Request request(&model);
97242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
98242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        // Go through all inputs
99242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        for (auto& i : example.first) {
100242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            std::vector<float>& input = i.second;
101242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            request.setInput(i.first, (const void*)input.data(),
102242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung                             input.size() * sizeof(float));
103242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        }
104242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
105242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        std::map<int, std::vector<float>> test_outputs;
106242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
107242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        assert(example.second.size() == 1);
108242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        int output_no = 0;
109242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        for (auto& i : example.second) {
110242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            std::vector<float>& output = i.second;
111242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            test_outputs[i.first].resize(output.size());
112242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            std::vector<float>& test_output = test_outputs[i.first];
113242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            request.setOutput(output_no++, (void*)test_output.data(),
114242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung                              test_output.size() * sizeof(float));
115242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        }
116242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        Result r = request.compute();
117242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        if (r != Result::NO_ERROR)
118242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            std::cerr << "Request was not completed normally\n";
119242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        bool mismatch = false;
120242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        for (auto& i : example.second) {
121242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            std::vector<float>& test = test_outputs[i.first];
122242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            std::vector<float>& golden = i.second;
123242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            for (unsigned i = 0; i < golden.size(); i++) {
124242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung                if (std::fabs(golden[i] - test[i]) > 1.5e-5f) {
125242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung                    std::cerr << " output[" << i << "] = " << test[i]
126242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung                              << " (should be " << golden[i] << ")\n";
127242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung                    error = error || true;
128242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung                    mismatch = mismatch || true;
129242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung                }
130242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            }
131242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        }
132242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        if (mismatch) {
133242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            std::cerr << "Example: " << example_no++;
134242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung            std::cerr << " failed\n";
135242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        }
136242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    }
137242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    return error;
138242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}
139242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
140242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sungclass GeneratedTests : public ::testing::Test {
141242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung   protected:
142242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    virtual void SetUp() {
143242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        ASSERT_EQ(android::nn::wrapper::Initialize(),
144242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung                  android::nn::wrapper::Result::NO_ERROR);
145242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    }
146242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
147242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    virtual void TearDown() { android::nn::wrapper::Shutdown(); }
148242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung};
149242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}  // namespace
150242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
151242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) SungTEST_F(GeneratedTests, conv_1_h3_w2_SAME) {
152242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    ASSERT_EQ(
153242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        Execute(conv_1_h3_w2_SAME::CreateModel, conv_1_h3_w2_SAME::examples),
154242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        0);
155242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}
156242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
157242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) SungTEST_F(GeneratedTests, conv_1_h3_w2_VALID) {
158242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    ASSERT_EQ(
159242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        Execute(conv_1_h3_w2_VALID::CreateModel, conv_1_h3_w2_VALID::examples),
160242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        0);
161242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}
162242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
163242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) SungTEST_F(GeneratedTests, conv_3_h3_w2_SAME) {
164242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    ASSERT_EQ(
165242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        Execute(conv_3_h3_w2_SAME::CreateModel, conv_3_h3_w2_SAME::examples),
166242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        0);
167242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}
168242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
169242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) SungTEST_F(GeneratedTests, conv_3_h3_w2_VALID) {
170242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    ASSERT_EQ(
171242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        Execute(conv_3_h3_w2_VALID::CreateModel, conv_3_h3_w2_VALID::examples),
172242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung        0);
173242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}
174242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
175242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) SungTEST_F(GeneratedTests, depthwise_conv) {
176242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    ASSERT_EQ(Execute(depthwise_conv::CreateModel, depthwise_conv::examples),
177242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung              0);
178242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}
179242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung
180242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) SungTEST_F(GeneratedTests, mobilenet) {
181242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung    ASSERT_EQ(Execute(mobilenet::CreateModel, mobilenet::examples), 0);
182242c6dc1f314646f1a87c66140f26d7623cc399aI-Jui (Ray) Sung}
183