1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15// Unit test for TFLite FULLY_CONNECTED op. 16 17#include <iomanip> 18#include <vector> 19 20#include <gmock/gmock.h> 21#include <gtest/gtest.h> 22#include "tensorflow/contrib/lite/interpreter.h" 23#include "tensorflow/contrib/lite/kernels/register.h" 24#include "tensorflow/contrib/lite/kernels/test_util.h" 25#include "tensorflow/contrib/lite/model.h" 26 27namespace tflite { 28namespace { 29 30using ::testing::ElementsAre; 31using ::testing::ElementsAreArray; 32 33static float fully_connected_input[] = { 34 0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653, 35 0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390, 36 0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314, 37 0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550, 38 0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112, 39 0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999, 40 0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142, 41 0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494, 42 0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081, 43 0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552, 44 0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320, 45 0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458, 46 0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115, 47 0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771, 48 0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582, 49 0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962, 50 0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202, 51 0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691, 52 0.921330, 0.139902}; 53 54static float fully_connected_golden_output[] = { 55 0, 0.0732134, 0, 0, 0, 0.280859, 56 0, 0.128927, 0, 0.0777251, 0, 0.270268, 57 0.271435, 0.0173503, 0.335465, 0.235562, 58 59 0, 0.0745866, 0, 0.051611, 0, 0.253876, 60 0, 0.0814873, 0, 0.104104, 0, 0.248529, 61 0.264194, 0, 0.302973, 0.166252, 62 63 0, 0.0170409, 0, 0.0509851, 0, 0.212834, 64 0, 0.0208326, 0, 0.129932, 0.203978, 0.103428, 65 0.298051, 0, 0.332233, 0.00445903, 66 67 0, 0.125246, 0, 0.0735336, 0, 0.0910256, 68 0, 0, 0, 0.18933, 0.378111, 0.0712443, 69 0.277298, 0.0123414, 0.267454, 0, 70 71 0, 0.14687, 0, 0.155495, 0.0300215, 0.147256, 72 0, 0, 0, 0.156412, 0.434914, 0.0461529, 73 0.246508, 0, 0.363138, 0, 74 75 0, 0, 0, 0.0212949, 0, 0.301708, 76 0, 0.35497, 0, 0.406223, 0.0260211, 0.049195, 77 0.197161, 0, 0.37316, 0, 78 79 0, 0.221783, 0, 0, 0.0116515, 0.281945, 80 0, 0, 0, 0, 0.285626, 0.181773, 81 0.296401, 0.170452, 0.367135, 0.142597, 82 83 0, 0, 0, 0, 0, 0.418886, 84 0, 0.291063, 0, 0.227541, 0.0424759, 0.27589, 85 0.398286, 0.177146, 0.40359, 0.121452, 86 87 0, 0.0834884, 0, 0, 0, 0.287441, 88 0, 0.0046838, 0, 0.0122087, 0, 0.217376, 89 0.140183, 0.0948412, 0.436677, 0.0589876, 90 91 0, 0.0289969, 0, 0.0921397, 0, 0.396802, 92 0, 0.0126157, 0, 0.0968433, 0, 0.172271, 93 0.173295, 0.0664741, 0.53645, 0.00915603, 94 95 0, 0, 0, 0, 0, 0.147942, 96 0, 0.263795, 0, 0.39782, 0, 0.382435, 97 0.561072, 0.0579847, 0.145712, 0.13508, 98 99 0, 0, 0, 0.16382, 0, 0.322294, 100 0, 0.163798, 0, 0.405211, 0.367953, 0.076852, 101 0.342473, 0.0834118, 0.377537, 0, 102 103 0, 0.206, 0, 0, 0, 0.375769, 104 0, 0, 0, 0, 0, 0.125165, 105 0, 0.105591, 0.52055, 0.0536445, 106 107 0, 0.259261, 0, 0, 0, 0.247707, 108 0, 0, 0, 0, 0, 0.215862, 109 0.149153, 0.224678, 0.359519, 0.129419, 110 111 0, 0.17611, 0, 0.280895, 0, 0.576484, 112 0, 0.000418848, 0, 0, 0, 0.151112, 113 0.211902, 0, 0.566341, 0.106305, 114 115 0, 0.0246284, 0, 0, 0, 0.196267, 116 0, 0.0248624, 0, 0.265635, 0, 0.436199, 117 0.408079, 0.134514, 0.328489, 0.411368}; 118 119class BaseFullyConnectedOpModel : public SingleOpModel { 120 public: 121 // TODO(ahentz): test different activation types too. 122 BaseFullyConnectedOpModel(int units, int batches, const TensorData& input, 123 const TensorData& output = {TensorType_FLOAT32}) 124 : batches_(batches), units_(units) { 125 int total_input_size = 1; 126 for (int i = 0; i < input.shape.size(); ++i) { 127 total_input_size *= input.shape[i]; 128 } 129 input_size_ = total_input_size / batches_; 130 131 input_ = AddInput(input); 132 weights_ = 133 AddInput({input.type, {units_, input_size_}, input.min, input.max}); 134 135 if (input.type == TensorType_FLOAT32) { 136 bias_ = AddInput({TensorType_FLOAT32, {units_}}); 137 } else { 138 // This is a quantized version. The scale of 'bias' depends on the scales 139 // of input and filter. Supposedly this is correctly set during quantized 140 // training. 141 auto bias_scale = GetScale(input_) * GetScale(weights_); 142 TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; 143 bias_ = AddInput(bias); 144 } 145 146 output_ = AddOutput(output); 147 148 SetBuiltinOp( 149 BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, 150 CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) 151 .Union()); 152 BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); 153 } 154 155 int input_size() { return input_size_; } 156 int num_units() { return units_; } 157 int num_batches() { return batches_; } 158 159 protected: 160 int input_; 161 int weights_; 162 int bias_; 163 int output_; 164 165 int batches_; 166 int units_; 167 int input_size_; 168}; 169 170class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel { 171 public: 172 using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; 173 174 void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); } 175 176 void SetWeights(std::initializer_list<float> f) { 177 PopulateTensor(weights_, f); 178 } 179 180 void SetInput(std::initializer_list<float> data) { 181 PopulateTensor(input_, data); 182 } 183 void SetInput(int offset, float* begin, float* end) { 184 PopulateTensor(input_, offset, begin, end); 185 } 186 187 std::vector<float> GetOutput() { return ExtractVector<float>(output_); } 188}; 189 190class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { 191 public: 192 using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; 193 194 void SetBias(std::initializer_list<float> data) { 195 QuantizeAndPopulate<int32_t>(bias_, data); 196 } 197 void SetWeights(std::initializer_list<float> data) { 198 QuantizeAndPopulate<uint8_t>(weights_, data); 199 } 200 void SetInput(std::initializer_list<float> data) { 201 QuantizeAndPopulate<uint8_t>(input_, data); 202 } 203 204 std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); } 205 std::vector<float> GetDequantizedOutput() { 206 return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_), 207 GetScale(output_), GetZeroPoint(output_)); 208 } 209}; 210 211// TODO(ahentz): add more small tests like this one, focused on making sure the 212// calculations are correct. 213TEST(FullyConnectedOpTest, SimpleTest) { 214 FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}}); 215 m.SetWeights({ 216 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 217 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 218 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 219 }); 220 m.SetBias({1, 2, 3}); 221 222 m.SetInput({ 223 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 224 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 225 }); 226 227 m.Invoke(); 228 229 EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); 230} 231 232TEST(FullyConnectedOpTest, SimpleTestQuantized) { 233 QuantizedFullyConnectedOpModel m( 234 3, 2, 235 /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64}, 236 /*output=*/{TensorType_UINT8, {}, -127, 128}); 237 238 // input_product_scale < output_scale was not true. 239 m.SetWeights({ 240 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 241 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 242 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 243 }); 244 m.SetBias({1, 2, 3}); 245 246 m.SetInput({ 247 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 248 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 249 }); 250 251 m.Invoke(); 252 253 EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ 254 24, 25, 26, // 255 58, 59, 60, // 256 }))); 257 EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); 258} 259 260TEST(FullyConnectedOpTest, SimpleTest4DInput) { 261 // Note that it is not required that the first dimension be the number of 262 // batches. All we care is that the input can be evenly distributed in 263 // batches. In this case, we need the input to have multiples of '2'. 264 FloatFullyConnectedOpModel m(/*units=*/3, 265 /*batches=*/2, 266 /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}}); 267 m.SetWeights({ 268 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 269 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 270 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 271 }); 272 m.SetBias({1, 2, 3}); 273 274 m.SetInput({ 275 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // first batch 276 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // second batch 277 }); 278 279 m.Invoke(); 280 281 EXPECT_THAT(m.GetOutput(), ElementsAreArray({ 282 24, 25, 26, // first batch 283 58, 59, 60, // second batch 284 })); 285} 286 287TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) { 288 QuantizedFullyConnectedOpModel m( 289 3, 2, 290 /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64}, 291 /*output=*/{TensorType_UINT8, {}, -127, 128}); 292 293 // input_product_scale < output_scale was not true. 294 m.SetWeights({ 295 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 296 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 297 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 298 }); 299 m.SetBias({1, 2, 3}); 300 301 m.SetInput({ 302 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 303 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 304 }); 305 306 m.Invoke(); 307 308 EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ 309 24, 25, 26, // 310 58, 59, 60, // 311 }))); 312 EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); 313} 314 315// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard 316// to debug errors and doesn't necessarily test all the important details. 317TEST(FullyConnectedOpTest, BlackBoxTest) { 318 FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}}); 319 m.SetWeights( 320 {0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636, 321 -0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504, 322 -0.275581, 0.059388, -0.118497, -0.079224, 0.109758, 0.008307, 323 -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650, 324 0.266455, 0.051517, -0.123448, 0.322464, 0.043282, -0.173782, 325 -0.190381, 0.002013, 0.096086, 0.131157, 0.031164, 0.100638, 326 -0.312191, -0.080923, -0.101318, -0.116614, 0.142238, 0.086540, 327 -0.139154, 0.174268, -0.073161, 0.080072, 0.006874, 0.229382, 328 -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824, 329 -0.025021, 0.074460, -0.252595, -0.161750, -0.136403, 0.008308, 330 0.005710, 0.096600, 0.289839, 0.218816, -0.304651, -0.070958, 331 0.054598, 0.147113, -0.139112, -0.072798, -0.163335, -0.167863, 332 -0.128762, -0.035780, 0.117262, 0.017177, 0.263335, -0.176612, 333 0.262961, -0.093654, -0.339283, 0.333071, 0.180827, 0.287583, 334 0.066350, -0.197947, -0.114449, -0.236035, 0.103532, -0.034284, 335 0.093299, -0.145361, 0.054001, 0.250570, 0.157010, -0.143480, 336 -0.139061, -0.048873, 0.067557, 0.139038, 0.324106, 0.227041, 337 0.037793, -0.225747, -0.241619, 0.357835, 0.135762, -0.306764, 338 -0.125982, 0.091916, 0.266587, 0.030135, 0.265148, 0.141627, 339 0.020120, 0.083815, -0.124556, -0.100124, -0.048159, 0.181172, 340 0.302309, -0.041084, 0.146334, -0.061511, -0.232605, 0.281324, 341 0.145408, -0.221897}); 342 m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860, 343 0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804, 344 0.048478, -0.032270, 0.175688, -0.085662}); 345 346 const int input_sequence_size = sizeof(fully_connected_input) / 347 sizeof(float) / 348 (m.input_size() * m.num_batches()); 349 for (int i = 0; i < input_sequence_size; i++) { 350 // TODO(ahentz): This is what the original test was doing: two equal 351 // batches per invocation. We could instead use two different batches. 352 float* batch_start = fully_connected_input + i * m.input_size(); 353 float* batch_end = batch_start + m.input_size(); 354 m.SetInput(0, batch_start, batch_end); 355 m.SetInput(m.input_size(), batch_start, batch_end); 356 357 m.Invoke(); 358 359 float* golden_start = fully_connected_golden_output + i * m.num_units(); 360 float* golden_end = golden_start + m.num_units(); 361 std::vector<float> expected; 362 expected.insert(expected.end(), golden_start, golden_end); 363 expected.insert(expected.end(), golden_start, golden_end); 364 365 EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); 366 } 367} 368 369} // namespace 370} // namespace tflite 371 372int main(int argc, char** argv) { 373 ::tflite::LogToStderr(); 374 ::testing::InitGoogleTest(&argc, argv); 375 return RUN_ALL_TESTS(); 376} 377