/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <cstdarg>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <gtest/gtest.h>
1814e0e7fe1eafd286f3813ba839b5f3236394a0a1Anna R#include "absl/memory/memory.h"
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/interpreter.h"
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/kernels/register.h"
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/kernels/test_util.h"
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/model.h"
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace tflite {
25e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling
26e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Lingnamespace ops {
27e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Lingnamespace builtin {
28e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling
29e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTfLiteRegistration* Register_CONVOLUTION_REF();
30e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT();
31e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT();
321a92f45677ee66af24f2219c6b1cbaeee87056b7Yu-Cheng LingTfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT();
33e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling
34e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling}  // namespace builtin
35e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling}  // namespace ops
36e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace {
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleusing ::testing::ElementsAreArray;
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
// Test harness that builds a model containing a single CONV_2D op, resolved
// to the specific kernel variant passed in as `registration`. Records the
// tensor indices of the op's inputs and output so subclasses can populate
// and read them.
class BaseConvolutionOpModel : public SingleOpModel {
 public:
  // TODO(ahentz): Also test different activation types, bias, padding types,
  // stride values.
  BaseConvolutionOpModel(
      TfLiteRegistration* registration, const TensorData& input,
      const TensorData& filter, const TensorData& output, int stride_width = 2,
      int stride_height = 2, enum Padding padding = Padding_VALID,
      enum ActivationFunctionType activation = ActivationFunctionType_NONE) {
    // NOTE: the order of AddInput/AddOutput calls determines tensor indices,
    // which must match the input order given to BuildInterpreter below.
    input_ = AddInput(input);
    filter_ = AddInput(filter);

    // One bias element per output channel (the filter's first dimension).
    int bias_size = GetShape(filter_)[0];
    if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
    } else {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(filter_);
      TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    }

    output_ = AddOutput(output);
    if (input.type != TensorType_FLOAT32) {
      // The following is required by quantized inference. It is the unittest's
      // responsibility to make sure the output scale falls into the correct
      // range.
      CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
    }

    SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
                 CreateConv2DOptions(builder_, padding, stride_width,
                                     stride_height, activation)
                     .Union());

    // Route BuiltinOperator_CONV_2D to the single kernel variant under test
    // (reference, generic-optimized, multithreaded, cblas, ...).
    resolver_ = absl::make_unique<SingleOpResolver>(BuiltinOperator_CONV_2D,
                                                    registration);
    BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
  }

 protected:
  int input_;   // Tensor index of the input activations.
  int filter_;  // Tensor index of the filter weights.
  int bias_;    // Tensor index of the bias.
  int output_;  // Tensor index of the output.
};
890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
// Float32 convolution model: thin wrappers around PopulateTensor /
// ExtractVector for setting the input, filter and bias tensors and reading
// back the float output.
class ConvolutionOpModel : public BaseConvolutionOpModel {
 public:
  using BaseConvolutionOpModel::BaseConvolutionOpModel;

  // Sets the filter weights (flattened, one 2D filter after another).
  void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }

  // Sets the per-output-channel bias values.
  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }

  // Sets the input activations (flattened in the input tensor's layout).
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }
  // Returns the output tensor's contents; call after Invoke().
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
// All CONV_2D kernel implementations exercised by this parameterized test,
// keyed by a human-readable name. Heap-allocated with `new` and never freed,
// so it safely outlives all test instantiations (fine for a test binary).
const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
    {"Reference", ops::builtin::Register_CONVOLUTION_REF()},
    {"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()},
    {"MultithreadedOptimized",
     ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()},
    {"CblasOptimized", ops::builtin::Register_CONVOLUTION_CBLAS_OPT()},
});
111e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling
// Value-parameterized fixture: the test parameter (handled by SingleOpTest)
// selects which entry of kKernelMap — i.e. which kernel variant — each
// TEST_P below runs against.
class ConvolutionOpTest : public SingleOpTest {
 protected:
  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
    return *kKernelMap;
  }
};
118e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling
119e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, SimpleTestFloat32) {
120e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling  ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
1210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                       {TensorType_FLOAT32, {3, 2, 2, 1}},
1220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                       {TensorType_FLOAT32, {}});
1230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetInput({
1250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      // First batch
1260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      1, 1, 1, 1,  // row = 1
1270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      2, 2, 2, 2,  // row = 2
1280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      // Second batch
1290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      1, 2, 3, 4,  // row = 1
1300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      1, 2, 3, 4,  // row = 2
1310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  });
1320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetFilter({
1330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      1, 2, 3, 4,    // first 2x2 filter
1340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      -1, 1, -1, 1,  // second 2x2 filter
1350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      -1, -1, 1, 1,  // third 2x2 filter
1360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  });
1370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetBias({1, 2, 3});
1380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.Invoke();
1400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
1420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 18, 2, 5,  // first batch, left
1430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 18, 2, 5,  // first batch, right
1440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 17, 4, 3,  // second batch, left
1450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 37, 4, 3,  // second batch, right
1460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                             }));
1470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
149e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
150e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling  ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}},
1510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                       {TensorType_FLOAT32, {1, 2, 2, 1}},
1520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                       {TensorType_FLOAT32, {}},
1530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                       /*stride_width=*/3, /*stride_height=*/1);
1540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetInput({
1550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      3, 2, 1, -1, -2, -3,  //
1560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      4, 3, 2, -2, -3, -4,  //
1570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      5, 4, 3, -3, -4, -5,  //
1580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  });
1590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetFilter({
1600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      1, 2,  //
1610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      3, 4,  //
1620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  });
1630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetBias({-1});
1640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.Invoke();
1650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
1660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 30, -24,  //
1670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 40, -34,  //
1680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                             }));
1690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
TEST_P(ConvolutionOpTest, HandCalculatedFloat32) {
  // Single 3x3 filter over a 3x4x1 image, unit strides, SAME padding, zero
  // bias. Every expected output value is worked out term by term below.
  const int depth = 1;
  const int image_width = 4;
  const int image_height = 3;
  const int image_batch_count = 1;
  const int filter_size = 3;
  const int filter_count = 1;
  const int stride_width = 1;
  const int stride_height = 1;
  const Padding padding = Padding_SAME;
  ConvolutionOpModel m(
      GetRegistration(),
      {TensorType_FLOAT32,
       {image_batch_count, image_height, image_width, depth}},
      {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
      {TensorType_FLOAT32, {}}, stride_width, stride_height, padding);

  // The image matrix is:
  // |  1 |  2 |  3 |  4 |
  // |  5 |  6 |  7 |  8 |
  // |  9 | 10 | 11 | 12 |
  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  // The filter matrix is:
  // | 1 | 4 | 7 |
  // | 2 | 5 | 8 |
  // | 3 | 6 | 9 |
  m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
  // No bias for this test.
  m.SetBias({0});

  m.Invoke();
  // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
  // the input set to zero because we're using the 'SAME' padding mode.
  // The calculations behind the expected output are:
  // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
  // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
  // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
  // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
  // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
  // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
  // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
  // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
  // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
  // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
  // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
  // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
  // This means we should end up with this matrix:
  // |  105  |  150  |  183  |   95  |
  // |  235  |  312  |  357  |  178  |
  // |  187  |  234  |  261  |  121  |
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357,
                                               178, 187, 234, 261, 121}));
}
2240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
TEST_P(ConvolutionOpTest, HandCalculatedWithBiasFloat32) {
  // Same setup as HandCalculatedFloat32, but with a bias of 10 added to
  // every output value.
  const int depth = 1;
  const int image_width = 4;
  const int image_height = 3;
  const int image_batch_count = 1;
  const int filter_size = 3;
  const int filter_count = 1;
  const int stride_width = 1;
  const int stride_height = 1;
  const Padding padding = Padding_SAME;
  ConvolutionOpModel m(
      GetRegistration(),
      {TensorType_FLOAT32,
       {image_batch_count, image_height, image_width, depth}},
      {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
      {TensorType_FLOAT32, {}}, stride_width, stride_height, padding);

  // The image matrix is:
  // |  1 |  2 |  3 |  4 |
  // |  5 |  6 |  7 |  8 |
  // |  9 | 10 | 11 | 12 |
  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  // The filter matrix is:
  // | 1 | 4 | 7 |
  // | 2 | 5 | 8 |
  // | 3 | 6 | 9 |
  m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
  // Bias is | 10 |.
  m.SetBias({10});

  m.Invoke();
  // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
  // the input set to zero because we're using the 'SAME' padding mode.
  // The calculations behind the expected output are:
  // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)+10=115
  // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)+10=160
  // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)+10=193
  // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)+10=105
  // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)+10=245
  // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)+10=322
  // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)+10=367
  // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)+10=188
  // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)+10=197
  // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)+10=244
  // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)+10=271
  // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)+10=131
  // This means we should end up with this matrix:
  // |  115  |  160  |  193  |  105  |
  // |  245  |  322  |  367  |  188  |
  // |  197  |  244  |  271  |  131  |
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({115, 160, 193, 105, 245, 322,
                                               367, 188, 197, 244, 271, 131}));
}
2780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
TEST_P(ConvolutionOpTest, HandCalculatedWithReluFloat32) {
  // Same setup as HandCalculatedFloat32, but with a large negative bias and
  // a RELU fused activation, so negative accumulations are clamped to zero.
  const int depth = 1;
  const int image_width = 4;
  const int image_height = 3;
  const int image_batch_count = 1;
  const int filter_size = 3;
  const int filter_count = 1;
  const int stride_width = 1;
  const int stride_height = 1;
  const Padding padding = Padding_SAME;
  ConvolutionOpModel m(
      GetRegistration(),
      {TensorType_FLOAT32,
       {image_batch_count, image_height, image_width, depth}},
      {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
      {TensorType_FLOAT32, {}}, stride_width, stride_height, padding,
      ActivationFunctionType_RELU);

  // The image matrix is:
  // |  1 |  2 |  3 |  4 |
  // |  5 |  6 |  7 |  8 |
  // |  9 | 10 | 11 | 12 |
  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  // The filter matrix is:
  // | 1 | 4 | 7 |
  // | 2 | 5 | 8 |
  // | 3 | 6 | 9 |
  m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
  // Bias is | -200 |.
  m.SetBias({-200});

  m.Invoke();
  // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
  // the input set to zero because we're using the 'SAME' padding mode.
  // The calculations behind the expected output are:
  // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)-200=-95
  // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)-200=-50
  // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)-200=-17
  // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)-200=-105
  // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)-200=35
  // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)-200=112
  // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)-200=157
  // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)-200=-22
  // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)-200=-13
  // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)-200=34
  // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)-200=61
  // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)-200=-79
  // All negative values are gated to zero by the Relu activation function.
  // This means we should end up with this matrix:
  // |   0 |   0 |   0 |   0 |
  // |  35 | 112 | 157 |   0 |
  // |   0 |  34 |  61 |   0 |
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0}));
}
3340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
335e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, HandCalculatedValidFloat32) {
3360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int depth = 1;
3370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int image_width = 4;
3380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int image_height = 3;
3390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int image_batch_count = 1;
3400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int filter_size = 3;
3410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int filter_count = 1;
3420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int stride_width = 1;
3430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int stride_height = 1;
3440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const Padding padding = Padding_VALID;
3450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  ConvolutionOpModel m(
346e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling      GetRegistration(),
3470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      {TensorType_FLOAT32,
3480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle       {image_batch_count, image_height, image_width, depth}},
3490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
3500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      {TensorType_FLOAT32, {}}, stride_width, stride_height, padding);
3510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // The image matrix is:
3530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // |  1 |  2 |  3 |  4 |
3540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // |  5 |  6 |  7 |  8 |
3550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // |  9 | 10 | 11 | 12 |
3560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
3570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // The filter matrix is:
3580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // | 1 | 4 | 7 |
3590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // | 2 | 5 | 8 |
3600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // | 3 | 6 | 9 |
3610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
3620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // No bias for this test.
3630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetBias({0});
3640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.Invoke();
3660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // We're sliding the 3x3 filter across the 3x4 image, with no accesses outside
3670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // the input because we're using the 'VALID' padding mode, giving a 2x1
3680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // output.
3690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // The calculations behind the expected output are:
3700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
3710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
3720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // This means we should end up with this matrix:
3730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // |  312  |  357  |
3740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357}));
3750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
3760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleclass QuantizedConvolutionOpModel : public BaseConvolutionOpModel {
3780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle public:
3790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  using BaseConvolutionOpModel::BaseConvolutionOpModel;
3800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  void SetInput(std::initializer_list<float> data) {
3820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    QuantizeAndPopulate<uint8_t>(input_, data);
3830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
3840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  void SetFilter(std::initializer_list<float> data) {
3860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    QuantizeAndPopulate<uint8_t>(filter_, data);
3870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
3880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  void SetBias(std::initializer_list<float> data) {
3900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    QuantizeAndPopulate<int32_t>(bias_, data);
3910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
3920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
3940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  std::vector<float> GetDequantizedOutput() {
3950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
3960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                               GetScale(output_), GetZeroPoint(output_));
3970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
3980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle};
3990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
// In these tests we set the input and output scales so that the results
// match exactly the 'non-quantized' version.
402e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, SimpleTestQuantized) {
403e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling  QuantizedConvolutionOpModel m(GetRegistration(),
404e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling                                {TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
4050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
4060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                {TensorType_UINT8, {}, -127, 128});
4070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetInput({
4080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      // First batch
4090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      1, 1, 1, 1,  // row = 1
4100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      2, 2, 2, 2,  // row = 2
4110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      // Second batch
4120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      1, 2, 3, 4,  // row = 1
4130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      1, 2, 3, 4,  // row = 2
4140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  });
4150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetFilter({
4160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      1, 2, 3, 4,    // first 2x2 filter
4170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      -1, 1, -1, 1,  // second 2x2 filter
4180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      -1, -1, 1, 1,  // third 2x2 filter
4190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  });
4200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetBias({1, 2, 3});
4210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
4220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.Invoke();
4230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
4240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(m.GetDequantizedOutput(),
4250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              ElementsAreArray(ArrayFloatNear(
4260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                  {
4270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                      18, 2, 5,  // first batch, left
4280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                      18, 2, 5,  // first batch, right
4290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                      17, 4, 3,  // second batch, left
4300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                      37, 4, 3,  // second batch, right
4310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                  },
4320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                  1e-5)));
4330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // For good  measure, let's also verify the quantized values:
4340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
4350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 145, 129, 132,  //
4360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 145, 129, 132,  //
4370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 144, 131, 130,  //
4380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 164, 131, 130,  //
4390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                             }));
4400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
4410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
442e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) {
443e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling  QuantizedConvolutionOpModel m(GetRegistration(),
444e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling                                {TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64},
4450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64},
4460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                {TensorType_UINT8, {}, -127, 128},
4470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                /*stride_width=*/3, /*stride_height=*/1);
4480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetInput({
4490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      3, 2, 1, -1, -2, -3,  //
4500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      4, 3, 2, -2, -3, -4,  //
4510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      5, 4, 3, -3, -4, -5,  //
4520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  });
4530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetFilter({
4540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      1, 2,  //
4550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      3, 4,  //
4560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  });
4570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.SetBias({-1});
4580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  m.Invoke();
4590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
4600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                            30, -24,  //
4610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                            40, -34,  //
4620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                        })));
4630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
4640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 157, 103,  //
4650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                 167, 93,   //
4660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                             }));
4670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
468e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling
// Instantiate every TEST_P above once per registered kernel variant in
// kKernelMap, so each test runs against all available implementations.
INSTANTIATE_TEST_CASE_P(
    ConvolutionOpTest, ConvolutionOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
472e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling
4730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace
4740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace tflite
4750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
int main(int argc, char** argv) {
  // Route TFLite logging to stderr before the tests start producing output.
  ::tflite::LogToStderr();
  // Parses gtest flags out of argv, then runs every registered test.
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
481