1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <gtest/gtest.h> 16#include "tensorflow/contrib/lite/interpreter.h" 17#include "tensorflow/contrib/lite/kernels/register.h" 18#include "tensorflow/contrib/lite/kernels/test_util.h" 19#include "tensorflow/contrib/lite/model.h" 20 21namespace tflite { 22namespace { 23 24using ::testing::ElementsAreArray; 25 26class BaseMulOpModel : public SingleOpModel { 27 public: 28 BaseMulOpModel(const TensorData& input1, const TensorData& input2, 29 const TensorData& output, 30 ActivationFunctionType activation_type) { 31 input1_ = AddInput(input1); 32 input2_ = AddInput(input2); 33 output_ = AddOutput(output); 34 SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, 35 CreateMulOptions(builder_, activation_type).Union()); 36 BuildInterpreter({GetShape(input1_), GetShape(input2_)}); 37 } 38 39 int input1() { return input1_; } 40 int input2() { return input2_; } 41 42 protected: 43 int input1_; 44 int input2_; 45 int output_; 46}; 47 48class FloatMulOpModel : public BaseMulOpModel { 49 public: 50 using BaseMulOpModel::BaseMulOpModel; 51 52 std::vector<float> GetOutput() { return ExtractVector<float>(output_); } 53}; 54 55// For quantized Mul, the error shouldn't exceed (2*step + step^2). 56// The param min=-1.0 & max=1.0 is used in the following tests. 57// The tolerance value is ~0.0157. 58const float kQuantizedStep = 2.0 / 255.0; 59const float kQuantizedTolerance = 60 2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep; 61 62class QuantizedMulOpModel : public BaseMulOpModel { 63 public: 64 using BaseMulOpModel::BaseMulOpModel; 65 66 std::vector<float> GetDequantizedOutput() { 67 return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_), 68 GetScale(output_), GetZeroPoint(output_)); 69 } 70}; 71 72TEST(FloatMulOpTest, NoActivation) { 73 FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, 74 {TensorType_FLOAT32, {1, 2, 2, 1}}, 75 {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); 76 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8}); 77 m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5}); 78 m.Invoke(); 79 EXPECT_THAT(m.GetOutput(), 80 ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); 81} 82 83TEST(FloatMulOpTest, ActivationRELU_N1_TO_1) { 84 FloatMulOpModel m( 85 {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, 86 {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU_N1_TO_1); 87 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8}); 88 m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 5}); 89 m.Invoke(); 90 EXPECT_THAT(m.GetOutput(), 91 ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 1.0}))); 92} 93 94TEST(FloatMulOpTest, VariousInputShapes) { 95 std::vector<std::initializer_list<int>> test_shapes = { 96 {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; 97 for (int i = 0; i < test_shapes.size(); ++i) { 98 FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]}, 99 {TensorType_FLOAT32, test_shapes[i]}, 100 {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); 101 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); 102 m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); 103 m.Invoke(); 104 EXPECT_THAT( 105 m.GetOutput(), 106 ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4, 1.21, 0.2}))) 107 << "With shape number " << i; 108 } 109} 110 111TEST(FloatMulOpTest, WithBroadcast) { 112 std::vector<std::initializer_list<int>> test_shapes = { 113 {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; 114 for (int i = 0; i < test_shapes.size(); ++i) { 115 FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]}, 116 {TensorType_FLOAT32, {}}, // always a scalar 117 {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); 118 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); 119 m.PopulateTensor<float>(m.input2(), {0.1}); 120 m.Invoke(); 121 EXPECT_THAT( 122 m.GetOutput(), 123 ElementsAreArray(ArrayFloatNear({-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}))) 124 << "With shape number " << i; 125 } 126} 127 128TEST(QuantizedMulOpTest, NoActivation) { 129 QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, 130 {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, 131 {TensorType_UINT8, {}, -1.0, 1.0}, 132 ActivationFunctionType_NONE); 133 m.QuantizeAndPopulate<uint8_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7}); 134 m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.6, 0.4, 0.9, 0.8}); 135 m.Invoke(); 136 EXPECT_THAT(m.GetDequantizedOutput(), 137 ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, 138 kQuantizedTolerance))); 139} 140 141// for quantized Mul, the error shouldn't exceed 2*step 142float GetTolerance(int min, int max) { 143 float kQuantizedStep = (max - min) / 255.0; 144 float kQuantizedTolerance = 2.0 * kQuantizedStep; 145 return kQuantizedTolerance; 146} 147 148TEST(QuantizedMulOpTest, WithBroadcast) { 149 float kQuantizedTolerance = GetTolerance(-3.0, 3.0); 150 std::vector<std::initializer_list<int>> test_shapes = { 151 {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; 152 for (int i = 0; i < test_shapes.size(); ++i) { 153 QuantizedMulOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, 154 {TensorType_UINT8, {}, -3.0, 3.0}, // always a scalar 155 {TensorType_UINT8, {}, -3.0, 3.0}, 156 ActivationFunctionType_NONE); 157 m.QuantizeAndPopulate<uint8_t>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); 158 m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.1}); 159 m.Invoke(); 160 EXPECT_THAT(m.GetDequantizedOutput(), 161 ElementsAreArray(ArrayFloatNear( 162 {-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}, kQuantizedTolerance))) 163 << "With shape number " << i; 164 } 165} 166 167} // namespace 168} // namespace tflite 169 170int main(int argc, char** argv) { 171 ::tflite::LogToStderr(); 172 ::testing::InitGoogleTest(&argc, argv); 173 return RUN_ALL_TESTS(); 174} 175