1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17/* Header-only library for various helpers of test harness 18 * See frameworks/ml/nn/runtime/test/TestGenerated.cpp for how this is used. 19 */ 20#ifndef ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H 21#define ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H 22 23#include <gtest/gtest.h> 24 25#include <functional> 26#include <map> 27#include <tuple> 28#include <vector> 29 30namespace generated_tests { 31typedef std::map<int, std::vector<float>> Float32Operands; 32typedef std::map<int, std::vector<int32_t>> Int32Operands; 33typedef std::map<int, std::vector<uint8_t>> Quant8Operands; 34typedef std::tuple<Float32Operands, // ANEURALNETWORKS_TENSOR_FLOAT32 35 Int32Operands, // ANEURALNETWORKS_TENSOR_INT32 36 Quant8Operands // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM 37 > 38 MixedTyped; 39typedef std::pair<MixedTyped, MixedTyped> MixedTypedExampleType; 40 41template <typename T> 42struct MixedTypedIndex {}; 43 44template <> 45struct MixedTypedIndex<float> { 46 static constexpr size_t index = 0; 47}; 48template <> 49struct MixedTypedIndex<int32_t> { 50 static constexpr size_t index = 1; 51}; 52template <> 53struct MixedTypedIndex<uint8_t> { 54 static constexpr size_t index = 2; 55}; 56 57// Go through all index-value pairs of a given input type 58template <typename T> 59inline void for_each(const MixedTyped& idx_and_data, 60 std::function<void(int, const std::vector<T>&)> execute) { 61 for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) { 62 execute(i.first, i.second); 63 } 64} 65 66// non-const variant of for_each 67template <typename T> 68inline void for_each(MixedTyped& idx_and_data, 69 std::function<void(int, std::vector<T>&)> execute) { 70 for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) { 71 execute(i.first, i.second); 72 } 73} 74 75// internal helper for for_all 76template <typename T> 77inline void for_all_internal( 78 MixedTyped& idx_and_data, 79 std::function<void(int, void*, size_t)> execute_this) { 80 for_each<T>(idx_and_data, [&execute_this](int idx, std::vector<T>& m) { 81 execute_this(idx, static_cast<void*>(m.data()), m.size() * sizeof(T)); 82 }); 83} 84 85// Go through all index-value pairs of all input types 86// expects a functor that takes (int index, void *raw data, size_t sz) 87inline void for_all(MixedTyped& idx_and_data, 88 std::function<void(int, void*, size_t)> execute_this) { 89 for_all_internal<float>(idx_and_data, execute_this); 90 for_all_internal<int32_t>(idx_and_data, execute_this); 91 for_all_internal<uint8_t>(idx_and_data, execute_this); 92} 93 94// Const variant of internal helper for for_all 95template <typename T> 96inline void for_all_internal( 97 const MixedTyped& idx_and_data, 98 std::function<void(int, const void*, size_t)> execute_this) { 99 for_each<T>(idx_and_data, [&execute_this](int idx, const std::vector<T>& m) { 100 execute_this(idx, static_cast<const void*>(m.data()), m.size() * sizeof(T)); 101 }); 102} 103 104// Go through all index-value pairs (const variant) 105// expects a functor that takes (int index, const void *raw data, size_t sz) 106inline void for_all( 107 const MixedTyped& idx_and_data, 108 std::function<void(int, const void*, size_t)> execute_this) { 109 for_all_internal<float>(idx_and_data, execute_this); 110 for_all_internal<int32_t>(idx_and_data, execute_this); 111 for_all_internal<uint8_t>(idx_and_data, execute_this); 112} 113 114// Helper template - resize test output per golden 115template <typename ty, size_t tuple_index> 116void resize_accordingly_(const MixedTyped& golden, MixedTyped& test) { 117 std::function<void(int, const std::vector<ty>&)> execute = 118 [&test](int index, const std::vector<ty>& m) { 119 auto& t = std::get<tuple_index>(test); 120 t[index].resize(m.size()); 121 }; 122 for_each<ty>(golden, execute); 123} 124 125inline void resize_accordingly(const MixedTyped& golden, MixedTyped& test) { 126 resize_accordingly_<float, 0>(golden, test); 127 resize_accordingly_<int32_t, 1>(golden, test); 128 resize_accordingly_<uint8_t, 2>(golden, test); 129} 130 131template <typename ty, size_t tuple_index> 132void filter_internal(const MixedTyped& golden, MixedTyped* filtered, 133 std::function<bool(int)> is_ignored) { 134 for_each<ty>(golden, 135 [filtered, &is_ignored](int index, const std::vector<ty>& m) { 136 auto& g = std::get<tuple_index>(*filtered); 137 if (!is_ignored(index)) g[index] = m; 138 }); 139} 140 141inline MixedTyped filter(const MixedTyped& golden, 142 std::function<bool(int)> is_ignored) { 143 MixedTyped filtered; 144 filter_internal<float, 0>(golden, &filtered, is_ignored); 145 filter_internal<int32_t, 1>(golden, &filtered, is_ignored); 146 filter_internal<uint8_t, 2>(golden, &filtered, is_ignored); 147 return filtered; 148} 149 150// Compare results 151#define VECTOR_TYPE(x) \ 152 typename std::tuple_element<x, MixedTyped>::type::mapped_type 153#define VALUE_TYPE(x) VECTOR_TYPE(x)::value_type 154template <size_t tuple_index> 155void compare_( 156 const MixedTyped& golden, const MixedTyped& test, 157 std::function<void(VALUE_TYPE(tuple_index), VALUE_TYPE(tuple_index))> 158 cmp) { 159 for_each<VALUE_TYPE(tuple_index)>( 160 golden, 161 [&test, &cmp](int index, const VECTOR_TYPE(tuple_index) & m) { 162 const auto& test_operands = std::get<tuple_index>(test); 163 const auto& test_ty = test_operands.find(index); 164 ASSERT_NE(test_ty, test_operands.end()); 165 for (unsigned int i = 0; i < m.size(); i++) { 166 SCOPED_TRACE(testing::Message() 167 << "When comparing element " << i); 168 cmp(m[i], test_ty->second[i]); 169 } 170 }); 171} 172#undef VALUE_TYPE 173#undef VECTOR_TYPE 174inline void compare(const MixedTyped& golden, const MixedTyped& test) { 175 compare_<0>(golden, test, 176 [](float g, float t) { EXPECT_NEAR(g, t, 1.e-5f); }); 177 compare_<1>(golden, test, [](int32_t g, int32_t t) { EXPECT_EQ(g, t); }); 178 compare_<2>(golden, test, [](uint8_t g, uint8_t t) { EXPECT_NEAR(g, t, 1); }); 179} 180 181}; // namespace generated_tests 182 183#endif // ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H 184