1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <cstdarg>
16#include <gtest/gtest.h>
17#include "tensorflow/contrib/lite/interpreter.h"
18#include "tensorflow/contrib/lite/kernels/register.h"
19#include "tensorflow/contrib/lite/kernels/test_util.h"
20#include "tensorflow/contrib/lite/model.h"
21
22namespace tflite {
23namespace {
24
25using ::testing::ElementsAreArray;
26
27class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
28 public:
29  // TODO(ahentz): Also test different activation types, bias, padding types,
30  // stride values.
31  BaseDepthwiseConvolutionOpModel(const TensorData& input,
32                                  const TensorData& filter,
33                                  const TensorData& output) {
34    input_ = AddInput(input);
35    filter_ = AddInput(filter);
36
37    int bias_size = GetShape(filter_)[3];
38    if (input.type == TensorType_FLOAT32) {
39      bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
40    } else {
41      // This is a quantized version. The scale of 'bias' depends on the scales
42      // of input and filter. Supposedly this is correctly set during quantized
43      // training.
44      auto bias_scale = GetScale(input_) * GetScale(filter_);
45      TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
46      bias_ = AddInput(bias);
47    }
48
49    output_ = AddOutput(output);
50    if (input.type != TensorType_FLOAT32) {
51      // The following is required by quantized inference. It is the unittest's
52      // responsibility to make sure the output scale falls into the correct
53      // range.
54      CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
55    }
56
57    int input_depth = GetShape(input_)[3];
58    int output_depth = GetShape(filter_)[3];
59    int depth_mul = output_depth / input_depth;
60
61    SetBuiltinOp(
62        BuiltinOperator_DEPTHWISE_CONV_2D,
63        BuiltinOptions_DepthwiseConv2DOptions,
64        CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
65                                     ActivationFunctionType_NONE)
66            .Union());
67
68    BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
69  }
70
71 protected:
72  int input_;
73  int filter_;
74  int bias_;
75  int output_;
76};
77
78class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel {
79 public:
80  using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel;
81
82  void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
83
84  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
85
86  void SetInput(std::initializer_list<float> data) {
87    PopulateTensor(input_, data);
88  }
89
90  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
91};
92
93TEST(DepthwiseConvolutionOpTest, SimpleTest) {
94  DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
95                                {TensorType_FLOAT32, {1, 2, 2, 4}},
96                                {TensorType_FLOAT32, {}});
97
98  m.SetInput({
99      1, 2, 7, 8,    // column 1
100      3, 4, 9, 10,   // column 2
101      5, 6, 11, 12,  // column 3
102  });
103  m.SetFilter({
104      1, 2, 3, 4,        //
105      -9, 10, -11, 12,   //
106      5, 6, 7, 8,        //
107      13, -14, 15, -16,  //
108  });
109  m.SetBias({1, 2, 3, 4});
110
111  m.Invoke();
112
113  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
114                                 71, -34, 99, -20,  //
115                                 91, -26, 127, -4,  //
116                             }));
117}
118
119class QuantizedDepthwiseConvolutionOpModel
120    : public BaseDepthwiseConvolutionOpModel {
121 public:
122  using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel;
123
124  void SetInput(std::initializer_list<float> data) {
125    QuantizeAndPopulate<uint8_t>(input_, data);
126  }
127
128  void SetFilter(std::initializer_list<float> data) {
129    QuantizeAndPopulate<uint8_t>(filter_, data);
130  }
131
132  void SetBias(std::initializer_list<float> data) {
133    QuantizeAndPopulate<int32_t>(bias_, data);
134  }
135
136  std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
137  std::vector<float> GetDequantizedOutput() {
138    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
139                               GetScale(output_), GetZeroPoint(output_));
140  }
141};
142
143// In this test we set the input and output scales so that the results match
144// exactly the 'non-quantized' version.
145TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
146  QuantizedDepthwiseConvolutionOpModel m(
147      {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
148      {TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64},
149      {TensorType_UINT8, {}, -127, 128});
150
151  m.SetInput({
152      1, 2, 7, 8,    // column 1
153      3, 4, 9, 10,   // column 2
154      5, 6, 11, 12,  // column 3
155  });
156  m.SetFilter({
157      1, 2, 3, 4,        //
158      -9, 10, -11, 12,   //
159      5, 6, 7, 8,        //
160      13, -14, 15, -16,  //
161  });
162  m.SetBias({1, 2, 3, 4});
163
164  m.Invoke();
165
166  EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
167                                            {
168                                                71, -34, 99, -20,  //
169                                                91, -26, 127, -4,  //
170                                            },
171                                            1e-5)));
172  // For good  measure, let's also verify the quantized values:
173  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
174                                 198, 93, 226, 107,   //
175                                 218, 101, 254, 123,  //
176                             }));
177}
178
179}  // namespace
180}  // namespace tflite
181
182int main(int argc, char** argv) {
183  ::tflite::LogToStderr();
184  ::testing::InitGoogleTest(&argc, argv);
185  return RUN_ALL_TESTS();
186}
187