1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15// Unit test for TFLite FULLY_CONNECTED op.
16
17#include <iomanip>
18#include <vector>
19
20#include <gmock/gmock.h>
21#include <gtest/gtest.h>
22#include "tensorflow/contrib/lite/interpreter.h"
23#include "tensorflow/contrib/lite/kernels/register.h"
24#include "tensorflow/contrib/lite/kernels/test_util.h"
25#include "tensorflow/contrib/lite/model.h"
26
27namespace tflite {
28namespace {
29
30using ::testing::ElementsAre;
31using ::testing::ElementsAreArray;
32
// Input activations consumed sequentially by BlackBoxTest below; each
// invocation reads input_size() floats starting at a moving offset.
static float fully_connected_input[] = {
    0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653,
    0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390,
    0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314,
    0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550,
    0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112,
    0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999,
    0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142,
    0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494,
    0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081,
    0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552,
    0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320,
    0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458,
    0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115,
    0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771,
    0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582,
    0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962,
    0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202,
    0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691,
    0.921330, 0.139902};
53
// Expected outputs for BlackBoxTest, grouped 16 values (one per output unit)
// per blank-line-separated block. The op is built with RELU activation, so
// the many zeros are negative pre-activations clipped to 0.
static float fully_connected_golden_output[] = {
    0,        0.0732134,   0,        0,          0,         0.280859,
    0,        0.128927,    0,        0.0777251,  0,         0.270268,
    0.271435, 0.0173503,   0.335465, 0.235562,

    0,        0.0745866,   0,        0.051611,   0,         0.253876,
    0,        0.0814873,   0,        0.104104,   0,         0.248529,
    0.264194, 0,           0.302973, 0.166252,

    0,        0.0170409,   0,        0.0509851,  0,         0.212834,
    0,        0.0208326,   0,        0.129932,   0.203978,  0.103428,
    0.298051, 0,           0.332233, 0.00445903,

    0,        0.125246,    0,        0.0735336,  0,         0.0910256,
    0,        0,           0,        0.18933,    0.378111,  0.0712443,
    0.277298, 0.0123414,   0.267454, 0,

    0,        0.14687,     0,        0.155495,   0.0300215, 0.147256,
    0,        0,           0,        0.156412,   0.434914,  0.0461529,
    0.246508, 0,           0.363138, 0,

    0,        0,           0,        0.0212949,  0,         0.301708,
    0,        0.35497,     0,        0.406223,   0.0260211, 0.049195,
    0.197161, 0,           0.37316,  0,

    0,        0.221783,    0,        0,          0.0116515, 0.281945,
    0,        0,           0,        0,          0.285626,  0.181773,
    0.296401, 0.170452,    0.367135, 0.142597,

    0,        0,           0,        0,          0,         0.418886,
    0,        0.291063,    0,        0.227541,   0.0424759, 0.27589,
    0.398286, 0.177146,    0.40359,  0.121452,

    0,        0.0834884,   0,        0,          0,         0.287441,
    0,        0.0046838,   0,        0.0122087,  0,         0.217376,
    0.140183, 0.0948412,   0.436677, 0.0589876,

    0,        0.0289969,   0,        0.0921397,  0,         0.396802,
    0,        0.0126157,   0,        0.0968433,  0,         0.172271,
    0.173295, 0.0664741,   0.53645,  0.00915603,

    0,        0,           0,        0,          0,         0.147942,
    0,        0.263795,    0,        0.39782,    0,         0.382435,
    0.561072, 0.0579847,   0.145712, 0.13508,

    0,        0,           0,        0.16382,    0,         0.322294,
    0,        0.163798,    0,        0.405211,   0.367953,  0.076852,
    0.342473, 0.0834118,   0.377537, 0,

    0,        0.206,       0,        0,          0,         0.375769,
    0,        0,           0,        0,          0,         0.125165,
    0,        0.105591,    0.52055,  0.0536445,

    0,        0.259261,    0,        0,          0,         0.247707,
    0,        0,           0,        0,          0,         0.215862,
    0.149153, 0.224678,    0.359519, 0.129419,

    0,        0.17611,     0,        0.280895,   0,         0.576484,
    0,        0.000418848, 0,        0,          0,         0.151112,
    0.211902, 0,           0.566341, 0.106305,

    0,        0.0246284,   0,        0,          0,         0.196267,
    0,        0.0248624,   0,        0.265635,   0,         0.436199,
    0.408079, 0.134514,    0.328489, 0.411368};
118
119class BaseFullyConnectedOpModel : public SingleOpModel {
120 public:
121  // TODO(ahentz): test different activation types too.
122  BaseFullyConnectedOpModel(int units, int batches, const TensorData& input,
123                            const TensorData& output = {TensorType_FLOAT32})
124      : batches_(batches), units_(units) {
125    int total_input_size = 1;
126    for (int i = 0; i < input.shape.size(); ++i) {
127      total_input_size *= input.shape[i];
128    }
129    input_size_ = total_input_size / batches_;
130
131    input_ = AddInput(input);
132    weights_ =
133        AddInput({input.type, {units_, input_size_}, input.min, input.max});
134
135    if (input.type == TensorType_FLOAT32) {
136      bias_ = AddInput({TensorType_FLOAT32, {units_}});
137    } else {
138      // This is a quantized version. The scale of 'bias' depends on the scales
139      // of input and filter. Supposedly this is correctly set during quantized
140      // training.
141      auto bias_scale = GetScale(input_) * GetScale(weights_);
142      TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale};
143      bias_ = AddInput(bias);
144    }
145
146    output_ = AddOutput(output);
147
148    SetBuiltinOp(
149        BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
150        CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
151            .Union());
152    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
153  }
154
155  int input_size() { return input_size_; }
156  int num_units() { return units_; }
157  int num_batches() { return batches_; }
158
159 protected:
160  int input_;
161  int weights_;
162  int bias_;
163  int output_;
164
165  int batches_;
166  int units_;
167  int input_size_;
168};
169
170class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel {
171 public:
172  using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
173
174  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
175
176  void SetWeights(std::initializer_list<float> f) {
177    PopulateTensor(weights_, f);
178  }
179
180  void SetInput(std::initializer_list<float> data) {
181    PopulateTensor(input_, data);
182  }
183  void SetInput(int offset, float* begin, float* end) {
184    PopulateTensor(input_, offset, begin, end);
185  }
186
187  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
188};
189
190class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
191 public:
192  using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
193
194  void SetBias(std::initializer_list<float> data) {
195    QuantizeAndPopulate<int32_t>(bias_, data);
196  }
197  void SetWeights(std::initializer_list<float> data) {
198    QuantizeAndPopulate<uint8_t>(weights_, data);
199  }
200  void SetInput(std::initializer_list<float> data) {
201    QuantizeAndPopulate<uint8_t>(input_, data);
202  }
203
204  std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
205  std::vector<float> GetDequantizedOutput() {
206    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
207                               GetScale(output_), GetZeroPoint(output_));
208  }
209};
210
211// TODO(ahentz): add more small tests like this one, focused on making sure the
212// calculations are correct.
TEST(FullyConnectedOpTest, SimpleTest) {
  // 3 output units, 2 batches, input shape {2, 10}.
  FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  // Each unit computes dot(input, weights) + bias: the dot product is 23 for
  // batch 0 and 57 for batch 1, so outputs are 23+{1,2,3} and 57+{1,2,3}.
  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
231
TEST(FullyConnectedOpTest, SimpleTestQuantized) {
  // Same computation as SimpleTest, but with uint8-quantized tensors.
  QuantizedFullyConnectedOpModel m(
      3, 2,
      /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  // Dequantized output matches the float SimpleTest within quantization
  // error.
  EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
                                            24, 25, 26,  //
                                            58, 59, 60,  //
                                        })));
  // Raw values are dequantized value + 127 — presumably the zero point for
  // the [-127, 128] output range; verify against the quantization scheme.
  EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
}
259
TEST(FullyConnectedOpTest, SimpleTest4DInput) {
  // Note that it is not required that the first dimension be the number of
  // batches. All we care is that the input can be evenly distributed in
  // batches. In this case, we need the input to have multiples of '2'.
  FloatFullyConnectedOpModel m(/*units=*/3,
                               /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // first batch
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // second batch
  });

  m.Invoke();

  // Identical expectations to SimpleTest: the 4D input flattens to the same
  // two 10-element batches.
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
                                 24, 25, 26,  // first batch
                                 58, 59, 60,  // second batch
                             }));
}
286
TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
  // Quantized counterpart of SimpleTest4DInput: a 4D input that flattens to
  // two 10-element batches.
  QuantizedFullyConnectedOpModel m(
      3, 2,
      /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  // Dequantized output matches the float test within quantization error.
  EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
                                            24, 25, 26,  //
                                            58, 59, 60,  //
                                        })));
  // Raw values are dequantized value + 127 — presumably the zero point for
  // the [-127, 128] output range; verify against the quantization scheme.
  EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
}
314
315// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
316// to debug errors and doesn't necessarily test all the important details.
TEST(FullyConnectedOpTest, BlackBoxTest) {
  // 16 output units, 2 batches, 8 input values per batch.
  FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}});
  m.SetWeights(
      {0.091327,  0.103366,  -0.316505, -0.083120, 0.149366,  -0.196636,
       -0.123672, 0.062800,  0.063031,  0.191670,  -0.062001, -0.061504,
       -0.275581, 0.059388,  -0.118497, -0.079224, 0.109758,  0.008307,
       -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650,
       0.266455,  0.051517,  -0.123448, 0.322464,  0.043282,  -0.173782,
       -0.190381, 0.002013,  0.096086,  0.131157,  0.031164,  0.100638,
       -0.312191, -0.080923, -0.101318, -0.116614, 0.142238,  0.086540,
       -0.139154, 0.174268,  -0.073161, 0.080072,  0.006874,  0.229382,
       -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824,
       -0.025021, 0.074460,  -0.252595, -0.161750, -0.136403, 0.008308,
       0.005710,  0.096600,  0.289839,  0.218816,  -0.304651, -0.070958,
       0.054598,  0.147113,  -0.139112, -0.072798, -0.163335, -0.167863,
       -0.128762, -0.035780, 0.117262,  0.017177,  0.263335,  -0.176612,
       0.262961,  -0.093654, -0.339283, 0.333071,  0.180827,  0.287583,
       0.066350,  -0.197947, -0.114449, -0.236035, 0.103532,  -0.034284,
       0.093299,  -0.145361, 0.054001,  0.250570,  0.157010,  -0.143480,
       -0.139061, -0.048873, 0.067557,  0.139038,  0.324106,  0.227041,
       0.037793,  -0.225747, -0.241619, 0.357835,  0.135762,  -0.306764,
       -0.125982, 0.091916,  0.266587,  0.030135,  0.265148,  0.141627,
       0.020120,  0.083815,  -0.124556, -0.100124, -0.048159, 0.181172,
       0.302309,  -0.041084, 0.146334,  -0.061511, -0.232605, 0.281324,
       0.145408,  -0.221897});
  m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860,
             0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804,
             0.048478, -0.032270, 0.175688, -0.085662});

  // Number of loop iterations: total floats in fully_connected_input divided
  // by the floats consumed per invocation (input_size * num_batches).
  const int input_sequence_size = sizeof(fully_connected_input) /
                                  sizeof(float) /
                                  (m.input_size() * m.num_batches());
  for (int i = 0; i < input_sequence_size; i++) {
    // TODO(ahentz): This is what the original test was doing: two equal
    // batches per invocation. We could instead use two different batches.
    float* batch_start = fully_connected_input + i * m.input_size();
    float* batch_end = batch_start + m.input_size();
    // Both batch slots receive the same input_size()-element slice.
    m.SetInput(0, batch_start, batch_end);
    m.SetInput(m.input_size(), batch_start, batch_end);

    m.Invoke();

    // Each golden group holds num_units() values; since both batches are
    // identical, the expected output is that group repeated twice.
    float* golden_start = fully_connected_golden_output + i * m.num_units();
    float* golden_end = golden_start + m.num_units();
    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
368
369}  // namespace
370}  // namespace tflite
371
// Test entry point: route TFLite log output to stderr, then run all gtest
// cases registered above.
int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
377