1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
/* Header-only library of helper functions for the generated-test harness.
 * See frameworks/ml/nn/runtime/test/TestGenerated.cpp for how this is used.
 */
20#ifndef ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
21#define ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
22
#include <gtest/gtest.h>

#include <functional>
#include <map>
#include <tuple>
#include <type_traits>
#include <vector>
29
30namespace generated_tests {
// Operand-index -> operand-data maps, one per supported tensor element type.
using Float32Operands = std::map<int, std::vector<float>>;   // ANEURALNETWORKS_TENSOR_FLOAT32
using Int32Operands = std::map<int, std::vector<int32_t>>;   // ANEURALNETWORKS_TENSOR_INT32
using Quant8Operands = std::map<int, std::vector<uint8_t>>;  // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM

// All operands of one example, grouped by element type. The slot order
// (float32 = 0, int32 = 1, quant8 = 2) is the one MixedTypedIndex encodes.
using MixedTyped = std::tuple<Float32Operands, Int32Operands, Quant8Operands>;

// One example as a pair of operand sets; presumably (inputs, expected
// outputs) — TODO confirm against the generated test sources.
using MixedTypedExampleType = std::pair<MixedTyped, MixedTyped>;
40
// Compile-time map from an operand element type to its slot in MixedTyped.
// Only float, int32_t and uint8_t are specialized; for any other T the primary
// template has no `index` member, so accidental use fails to compile.
template <typename T>
struct MixedTypedIndex {};

template <> struct MixedTypedIndex<float>   { static constexpr size_t index{0}; };  // Float32Operands
template <> struct MixedTypedIndex<int32_t> { static constexpr size_t index{1}; };  // Int32Operands
template <> struct MixedTypedIndex<uint8_t> { static constexpr size_t index{2}; };  // Quant8Operands
56
57// Go through all index-value pairs of a given input type
58template <typename T>
59inline void for_each(const MixedTyped& idx_and_data,
60                     std::function<void(int, const std::vector<T>&)> execute) {
61    for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) {
62        execute(i.first, i.second);
63    }
64}
65
66// non-const variant of for_each
67template <typename T>
68inline void for_each(MixedTyped& idx_and_data,
69                     std::function<void(int, std::vector<T>&)> execute) {
70    for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) {
71        execute(i.first, i.second);
72    }
73}
74
75// internal helper for for_all
76template <typename T>
77inline void for_all_internal(
78        MixedTyped& idx_and_data,
79        std::function<void(int, void*, size_t)> execute_this) {
80    for_each<T>(idx_and_data, [&execute_this](int idx, std::vector<T>& m) {
81        execute_this(idx, static_cast<void*>(m.data()), m.size() * sizeof(T));
82    });
83}
84
85// Go through all index-value pairs of all input types
86// expects a functor that takes (int index, void *raw data, size_t sz)
87inline void for_all(MixedTyped& idx_and_data,
88                    std::function<void(int, void*, size_t)> execute_this) {
89    for_all_internal<float>(idx_and_data, execute_this);
90    for_all_internal<int32_t>(idx_and_data, execute_this);
91    for_all_internal<uint8_t>(idx_and_data, execute_this);
92}
93
94// Const variant of internal helper for for_all
95template <typename T>
96inline void for_all_internal(
97        const MixedTyped& idx_and_data,
98        std::function<void(int, const void*, size_t)> execute_this) {
99    for_each<T>(idx_and_data, [&execute_this](int idx, const std::vector<T>& m) {
100        execute_this(idx, static_cast<const void*>(m.data()), m.size() * sizeof(T));
101    });
102}
103
104// Go through all index-value pairs (const variant)
105// expects a functor that takes (int index, const void *raw data, size_t sz)
106inline void for_all(
107        const MixedTyped& idx_and_data,
108        std::function<void(int, const void*, size_t)> execute_this) {
109    for_all_internal<float>(idx_and_data, execute_this);
110    for_all_internal<int32_t>(idx_and_data, execute_this);
111    for_all_internal<uint8_t>(idx_and_data, execute_this);
112}
113
114// Helper template - resize test output per golden
115template <typename ty, size_t tuple_index>
116void resize_accordingly_(const MixedTyped& golden, MixedTyped& test) {
117    std::function<void(int, const std::vector<ty>&)> execute =
118            [&test](int index, const std::vector<ty>& m) {
119                auto& t = std::get<tuple_index>(test);
120                t[index].resize(m.size());
121            };
122    for_each<ty>(golden, execute);
123}
124
125inline void resize_accordingly(const MixedTyped& golden, MixedTyped& test) {
126    resize_accordingly_<float, 0>(golden, test);
127    resize_accordingly_<int32_t, 1>(golden, test);
128    resize_accordingly_<uint8_t, 2>(golden, test);
129}
130
131template <typename ty, size_t tuple_index>
132void filter_internal(const MixedTyped& golden, MixedTyped* filtered,
133                     std::function<bool(int)> is_ignored) {
134    for_each<ty>(golden,
135                 [filtered, &is_ignored](int index, const std::vector<ty>& m) {
136                     auto& g = std::get<tuple_index>(*filtered);
137                     if (!is_ignored(index)) g[index] = m;
138                 });
139}
140
141inline MixedTyped filter(const MixedTyped& golden,
142                         std::function<bool(int)> is_ignored) {
143    MixedTyped filtered;
144    filter_internal<float, 0>(golden, &filtered, is_ignored);
145    filter_internal<int32_t, 1>(golden, &filtered, is_ignored);
146    filter_internal<uint8_t, 2>(golden, &filtered, is_ignored);
147    return filtered;
148}
149
150// Compare results
151#define VECTOR_TYPE(x) \
152    typename std::tuple_element<x, MixedTyped>::type::mapped_type
153#define VALUE_TYPE(x) VECTOR_TYPE(x)::value_type
154template <size_t tuple_index>
155void compare_(
156        const MixedTyped& golden, const MixedTyped& test,
157        std::function<void(VALUE_TYPE(tuple_index), VALUE_TYPE(tuple_index))>
158                cmp) {
159    for_each<VALUE_TYPE(tuple_index)>(
160            golden,
161            [&test, &cmp](int index, const VECTOR_TYPE(tuple_index) & m) {
162                const auto& test_operands = std::get<tuple_index>(test);
163                const auto& test_ty = test_operands.find(index);
164                ASSERT_NE(test_ty, test_operands.end());
165                for (unsigned int i = 0; i < m.size(); i++) {
166                    SCOPED_TRACE(testing::Message()
167                                 << "When comparing element " << i);
168                    cmp(m[i], test_ty->second[i]);
169                }
170            });
171}
172#undef VALUE_TYPE
173#undef VECTOR_TYPE
174inline void compare(const MixedTyped& golden, const MixedTyped& test) {
175    compare_<0>(golden, test,
176                [](float g, float t) { EXPECT_NEAR(g, t, 1.e-5f); });
177    compare_<1>(golden, test, [](int32_t g, int32_t t) { EXPECT_EQ(g, t); });
178    compare_<2>(golden, test, [](uint8_t g, uint8_t t) { EXPECT_NEAR(g, t, 1); });
179}
180
181};  // namespace generated_tests
182
183#endif  // ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
184