1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <gtest/gtest.h> 16#include "tensorflow/contrib/lite/interpreter.h" 17#include "tensorflow/contrib/lite/kernels/register.h" 18#include "tensorflow/contrib/lite/kernels/test_util.h" 19#include "tensorflow/contrib/lite/model.h" 20 21namespace tflite { 22namespace { 23 24using ::testing::ElementsAreArray; 25 26class BaseAddOpModel : public SingleOpModel { 27 public: 28 BaseAddOpModel(const TensorData& input1, const TensorData& input2, 29 const TensorData& output, 30 ActivationFunctionType activation_type) { 31 input1_ = AddInput(input1); 32 input2_ = AddInput(input2); 33 output_ = AddOutput(output); 34 SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, 35 CreateAddOptions(builder_, activation_type).Union()); 36 BuildInterpreter({GetShape(input1_), GetShape(input2_)}); 37 } 38 39 int input1() { return input1_; } 40 int input2() { return input2_; } 41 42 protected: 43 int input1_; 44 int input2_; 45 int output_; 46}; 47 48class FloatAddOpModel : public BaseAddOpModel { 49 public: 50 using BaseAddOpModel::BaseAddOpModel; 51 52 std::vector<float> GetOutput() { return ExtractVector<float>(output_); } 53}; 54 55class QuantizedAddOpModel : public BaseAddOpModel { 56 public: 57 using BaseAddOpModel::BaseAddOpModel; 58 59 std::vector<float> GetDequantizedOutput() { 60 return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_), 61 GetScale(output_), GetZeroPoint(output_)); 62 } 63}; 64 65// for quantized Add, the error shouldn't exceed 2*step 66float GetTolerance(int min, int max) { 67 float kQuantizedStep = (max - min) / 255.0; 68 float kQuantizedTolerance = 2.0 * kQuantizedStep; 69 return kQuantizedTolerance; 70} 71 72TEST(FloatAddOpModel, NoActivation) { 73 FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, 74 {TensorType_FLOAT32, {1, 2, 2, 1}}, 75 {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); 76 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8}); 77 m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5}); 78 m.Invoke(); 79 EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); 80} 81 82TEST(FloatAddOpModel, ActivationRELU_N1_TO_1) { 83 FloatAddOpModel m( 84 {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, 85 {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU_N1_TO_1); 86 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8}); 87 m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5}); 88 m.Invoke(); 89 EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 0.4, 1.0, 1.0})); 90} 91 92TEST(FloatAddOpModel, VariousInputShapes) { 93 std::vector<std::initializer_list<int>> test_shapes = { 94 {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; 95 for (int i = 0; i < test_shapes.size(); ++i) { 96 FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, 97 {TensorType_FLOAT32, test_shapes[i]}, 98 {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); 99 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); 100 m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); 101 m.Invoke(); 102 EXPECT_THAT(m.GetOutput(), 103 ElementsAreArray({-1.9, 0.4, 1.0, 1.3, 2.2, 2.1})) 104 << "With shape number " << i; 105 } 106} 107 108TEST(FloatAddOpModel, WithBroadcast) { 109 std::vector<std::initializer_list<int>> test_shapes = { 110 {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; 111 for (int i = 0; i < test_shapes.size(); ++i) { 112 FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, 113 {TensorType_FLOAT32, {}}, // always a scalar 114 {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); 115 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); 116 m.PopulateTensor<float>(m.input2(), {0.1}); 117 m.Invoke(); 118 EXPECT_THAT( 119 m.GetOutput(), 120 ElementsAreArray(ArrayFloatNear({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}))) 121 << "With shape number " << i; 122 } 123} 124 125TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { 126 float kQuantizedTolerance = GetTolerance(-1.0, 1.0); 127 std::vector<std::initializer_list<float>> inputs1 = { 128 {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; 129 std::vector<std::initializer_list<float>> inputs2 = { 130 {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}}; 131 std::vector<std::initializer_list<float>> results = { 132 {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; 133 for (int i = 0; i < inputs1.size(); ++i) { 134 QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, 135 {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, 136 {TensorType_UINT8, {}, -1.0, 1.0}, 137 ActivationFunctionType_NONE); 138 m.QuantizeAndPopulate<uint8_t>(m.input1(), inputs1[i]); 139 m.QuantizeAndPopulate<uint8_t>(m.input2(), inputs2[i]); 140 m.Invoke(); 141 EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( 142 results[i], kQuantizedTolerance))) 143 << "With test number " << i; 144 } 145} 146 147TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { 148 float kQuantizedTolerance = GetTolerance(-1.0, 1.0); 149 std::vector<std::initializer_list<float>> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, 150 {-0.8, 0.2, 0.7, 0.3}}; 151 std::vector<std::initializer_list<float>> inputs2 = {{0.6, 0.4, 0.9, -0.8}, 152 {0.6, 0.4, -0.8, 0.5}}; 153 std::vector<std::initializer_list<float>> results = {{-0.2, 0.6, 1.0, -0.1}, 154 {-0.2, 0.6, -0.1, 0.8}}; 155 for (int i = 0; i < inputs1.size(); ++i) { 156 QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, 157 {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, 158 {TensorType_UINT8, {}, -1.0, 1.0}, 159 ActivationFunctionType_RELU_N1_TO_1); 160 m.QuantizeAndPopulate<uint8_t>(m.input1(), inputs1[i]); 161 m.QuantizeAndPopulate<uint8_t>(m.input2(), inputs2[i]); 162 m.Invoke(); 163 EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( 164 results[i], kQuantizedTolerance))) 165 << "With test number " << i; 166 } 167} 168 169TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { 170 float kQuantizedTolerance = GetTolerance(-3.0, 3.0); 171 std::vector<std::initializer_list<int>> test_shapes = { 172 {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; 173 for (int i = 0; i < test_shapes.size(); ++i) { 174 QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, 175 {TensorType_UINT8, test_shapes[i], -3.0, 3.0}, 176 {TensorType_UINT8, {}, -3.0, 3.0}, 177 ActivationFunctionType_NONE); 178 m.QuantizeAndPopulate<uint8_t>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); 179 m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.1, 0.3, 0.3, 0.5, 1.1, 0.1}); 180 m.Invoke(); 181 EXPECT_THAT(m.GetDequantizedOutput(), 182 ElementsAreArray(ArrayFloatNear({-1.9, 0.5, 1.0, 1.3, 2.2, 2.1}, 183 kQuantizedTolerance))) 184 << "With shape number " << i; 185 } 186} 187 188TEST(QuantizedAddOpModel, QuantizedWithBroadcast) { 189 float kQuantizedTolerance = GetTolerance(-3.0, 3.0); 190 std::vector<std::initializer_list<int>> test_shapes = { 191 {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; 192 for (int i = 0; i < test_shapes.size(); ++i) { 193 QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, 194 {TensorType_UINT8, {}, -3.0, 3.0}, 195 {TensorType_UINT8, {}, -3.0, 3.0}, 196 ActivationFunctionType_NONE); 197 m.QuantizeAndPopulate<uint8_t>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); 198 m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.1}); 199 m.Invoke(); 200 EXPECT_THAT(m.GetDequantizedOutput(), 201 ElementsAreArray(ArrayFloatNear({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}, 202 kQuantizedTolerance))) 203 << "With shape number " << i; 204 } 205} 206 207} // namespace 208} // namespace tflite 209int main(int argc, char** argv) { 210 ::tflite::LogToStderr(); 211 ::testing::InitGoogleTest(&argc, argv); 212 return RUN_ALL_TESTS(); 213} 214