10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License"); 40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License. 50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at 60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle http://www.apache.org/licenses/LICENSE-2.0 80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software 100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS, 110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and 130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License. 140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/ 150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <cstdarg> 160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <gtest/gtest.h> 1814e0e7fe1eafd286f3813ba839b5f3236394a0a1Anna R#include "absl/memory/memory.h" 190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/interpreter.h" 200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/kernels/register.h" 210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/kernels/test_util.h" 220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/model.h" 230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace tflite { 25e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling 26e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Lingnamespace ops { 27e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Lingnamespace builtin { 28e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling 29e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTfLiteRegistration* Register_CONVOLUTION_REF(); 30e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT(); 31e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT(); 321a92f45677ee66af24f2219c6b1cbaeee87056b7Yu-Cheng LingTfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT(); 33e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling 34e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling} // namespace builtin 35e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling} // namespace ops 36e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling 370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace { 380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleusing ::testing::ElementsAreArray; 400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleclass BaseConvolutionOpModel : public SingleOpModel { 420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle public: 430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // TODO(ahentz): Also test different activation types, bias, padding types, 440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // stride values. 450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle BaseConvolutionOpModel( 46e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling TfLiteRegistration* registration, const TensorData& input, 47e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling const TensorData& filter, const TensorData& output, int stride_width = 2, 48e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling int stride_height = 2, enum Padding padding = Padding_VALID, 490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle enum ActivationFunctionType activation = ActivationFunctionType_NONE) { 500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle input_ = AddInput(input); 510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle filter_ = AddInput(filter); 520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle int bias_size = GetShape(filter_)[0]; 540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (input.type == TensorType_FLOAT32) { 550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); 560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } else { 570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // This is a quantized version. The scale of 'bias' depends on the scales 580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // of input and filter. Supposedly this is correctly set during quantized 590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // training. 600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto bias_scale = GetScale(input_) * GetScale(filter_); 610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; 620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle bias_ = AddInput(bias); 630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle output_ = AddOutput(output); 660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (input.type != TensorType_FLOAT32) { 670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The following is required by quantized inference. It is the unittest's 680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // responsibility to make sure the output scale falls into the correct 690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // range. 700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); 710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, 740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CreateConv2DOptions(builder_, padding, stride_width, 750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle stride_height, activation) 760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle .Union()); 770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 78e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling resolver_ = absl::make_unique<SingleOpResolver>(BuiltinOperator_CONV_2D, 79e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling registration); 800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); 810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle protected: 840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle int input_; 850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle int filter_; 860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle int bias_; 870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle int output_; 880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}; 890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleclass ConvolutionOpModel : public BaseConvolutionOpModel { 910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle public: 920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle using BaseConvolutionOpModel::BaseConvolutionOpModel; 930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); } 950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); } 970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle void SetInput(std::initializer_list<float> data) { 990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle PopulateTensor(input_, data); 1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle std::vector<float> GetOutput() { return ExtractVector<float>(output_); } 1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}; 1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 104e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Lingconst auto kKernelMap = new std::map<string, TfLiteRegistration*>({ 105e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling {"Reference", ops::builtin::Register_CONVOLUTION_REF()}, 106e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling {"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()}, 107e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling {"MultithreadedOptimized", 108e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()}, 1091a92f45677ee66af24f2219c6b1cbaeee87056b7Yu-Cheng Ling {"CblasOptimized", ops::builtin::Register_CONVOLUTION_CBLAS_OPT()}, 110e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling}); 111e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling 112e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Lingclass ConvolutionOpTest : public SingleOpTest { 113e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling protected: 114e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling const std::map<string, TfLiteRegistration*>& GetKernelMap() override { 115e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling return *kKernelMap; 116e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling } 117e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling}; 118e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling 119e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, SimpleTestFloat32) { 120e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}}, 1210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {3, 2, 2, 1}}, 1220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {}}); 1230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetInput({ 1250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // First batch 1260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1, 1, 1, 1, // row = 1 1270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2, 2, 2, 2, // row = 2 1280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Second batch 1290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1, 2, 3, 4, // row = 1 1300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1, 2, 3, 4, // row = 2 1310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }); 1320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetFilter({ 1330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1, 2, 3, 4, // first 2x2 filter 1340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle -1, 1, -1, 1, // second 2x2 filter 1350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle -1, -1, 1, 1, // third 2x2 filter 1360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }); 1370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetBias({1, 2, 3}); 1380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.Invoke(); 1400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle EXPECT_THAT(m.GetOutput(), ElementsAreArray({ 1420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 18, 2, 5, // first batch, left 1430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 18, 2, 5, // first batch, right 1440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 17, 4, 3, // second batch, left 1450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 37, 4, 3, // second batch, right 1460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle })); 1470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 1480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 149e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { 150e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}}, 1510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {1, 2, 2, 1}}, 1520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {}}, 1530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle /*stride_width=*/3, /*stride_height=*/1); 1540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetInput({ 1550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3, 2, 1, -1, -2, -3, // 1560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 4, 3, 2, -2, -3, -4, // 1570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 5, 4, 3, -3, -4, -5, // 1580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }); 1590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetFilter({ 1600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1, 2, // 1610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3, 4, // 1620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }); 1630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetBias({-1}); 1640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.Invoke(); 1650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle EXPECT_THAT(m.GetOutput(), ElementsAreArray({ 1660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 30, -24, // 1670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 40, -34, // 1680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle })); 1690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 1700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 171e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, HandCalculatedFloat32) { 1720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int depth = 1; 1730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_width = 4; 1740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_height = 3; 1750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_batch_count = 1; 1760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int filter_size = 3; 1770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int filter_count = 1; 1780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int stride_width = 1; 1790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int stride_height = 1; 1800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const Padding padding = Padding_SAME; 1810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle ConvolutionOpModel m( 182e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling GetRegistration(), 1830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, 1840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {image_batch_count, image_height, image_width, depth}}, 1850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, 1860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); 1870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The image matrix is: 1890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 1 | 2 | 3 | 4 | 1900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 5 | 6 | 7 | 8 | 1910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 9 | 10 | 11 | 12 | 1920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); 1930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The filter matrix is: 1940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 1 | 4 | 7 | 1950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 2 | 5 | 8 | 1960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 3 | 6 | 9 | 1970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); 1980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // No bias for this test. 1990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetBias({0}); 2000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.Invoke(); 2020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // We're sliding the 3x3 filter across the 3x4 image, with accesses outside 2030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // the input set to zero because we're using the 'SAME' padding mode. 2040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The calculations behind the expected output are: 2050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105 2060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150 2070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183 2080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95 2090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235 2100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 2110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 2120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178 2130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187 2140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234 2150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261 2160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121 2170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // This means we should end up with this matrix: 2180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 105 | 150 | 183 | 95 | 2190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 235 | 312 | 357 | 178 | 2200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 187 | 234 | 261 | 121 | 2210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357, 2220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 178, 187, 234, 261, 121})); 2230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 2240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 225e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { 2260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int depth = 1; 2270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_width = 4; 2280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_height = 3; 2290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_batch_count = 1; 2300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int filter_size = 3; 2310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int filter_count = 1; 2320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int stride_width = 1; 2330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int stride_height = 1; 2340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const Padding padding = Padding_SAME; 2350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle ConvolutionOpModel m( 236e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling GetRegistration(), 2370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, 2380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {image_batch_count, image_height, image_width, depth}}, 2390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, 2400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); 2410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The image matrix is: 2430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 1 | 2 | 3 | 4 | 2440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 5 | 6 | 7 | 8 | 2450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 9 | 10 | 11 | 12 | 2460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); 2470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The filter matrix is: 2480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 1 | 4 | 7 | 2490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 2 | 5 | 8 | 2500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 3 | 6 | 9 | 2510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); 2520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Bias is | 10 |. 2530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetBias({10}); 2540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.Invoke(); 2560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // We're sliding the 3x3 filter across the 3x4 image, with accesses outside 2570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // the input set to zero because we're using the 'SAME' padding mode. 2580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The calculations behind the expected output are: 2590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)+10=115 2600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)+10=160 2610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)+10=193 2620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)+10=105 2630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)+10=245 2640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)+10=322 2650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)+10=367 2660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)+10=188 2670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)+10=197 2680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)+10=244 2690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)+10=271 2700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)+10=131 2710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // This means we should end up with this matrix: 2720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 115 | 160 | 193 | 105 | 2730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 245 | 322 | 367 | 188 | 2740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 197 | 244 | 271 | 131 | 2750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle EXPECT_THAT(m.GetOutput(), ElementsAreArray({115, 160, 193, 105, 245, 322, 2760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 367, 188, 197, 244, 271, 131})); 2770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 2780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 279e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, HandCalculatedWithReluFloat32) { 2800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int depth = 1; 2810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_width = 4; 2820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_height = 3; 2830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_batch_count = 1; 2840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int filter_size = 3; 2850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int filter_count = 1; 2860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int stride_width = 1; 2870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int stride_height = 1; 2880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const Padding padding = Padding_SAME; 2890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle ConvolutionOpModel m( 290e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling GetRegistration(), 2910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, 2920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {image_batch_count, image_height, image_width, depth}}, 2930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, 2940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {}}, stride_width, stride_height, padding, 2950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle ActivationFunctionType_RELU); 2960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The image matrix is: 2980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 1 | 2 | 3 | 4 | 2990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 5 | 6 | 7 | 8 | 3000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 9 | 10 | 11 | 12 | 3010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); 3020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The filter matrix is: 3030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 1 | 4 | 7 | 3040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 2 | 5 | 8 | 3050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 3 | 6 | 9 | 3060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); 3070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Bias is | -200 |. 3080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetBias({-200}); 3090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.Invoke(); 3110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // We're sliding the 3x3 filter across the 3x4 image, with accesses outside 3120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // the input set to zero because we're using the 'SAME' padding mode. 3130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The calculations behind the expected output are: 3140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)-200=-95 3150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)-200=-50 3160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)-200=-17 3170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)-200=-105 3180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)-200=35 3190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)-200=112 3200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)-200=157 3210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)-200=-22 3220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)-200=-13 3230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)-200=34 3240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)-200=61 3250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)-200=-79 3260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // All negative values are gated to zero by the Relu activation function. 3270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // This means we should end up with this matrix: 3280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 0 | 0 | 0 | 0 | 3290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 35 | 112 | 157 | 0 | 3300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 0 | 34 | 61 | 0 | 3310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle EXPECT_THAT(m.GetOutput(), 3320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0})); 3330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 3340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 335e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, HandCalculatedValidFloat32) { 3360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int depth = 1; 3370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_width = 4; 3380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_height = 3; 3390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int image_batch_count = 1; 3400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int filter_size = 3; 3410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int filter_count = 1; 3420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int stride_width = 1; 3430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int stride_height = 1; 3440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const Padding padding = Padding_VALID; 3450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle ConvolutionOpModel m( 346e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling GetRegistration(), 3470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, 3480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {image_batch_count, image_height, image_width, depth}}, 3490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, 3500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); 3510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The image matrix is: 3530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 1 | 2 | 3 | 4 | 3540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 5 | 6 | 7 | 8 | 3550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 9 | 10 | 11 | 12 | 3560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); 3570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The filter matrix is: 3580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 1 | 4 | 7 | 3590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 2 | 5 | 8 | 3600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 3 | 6 | 9 | 3610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); 3620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // No bias for this test. 3630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetBias({0}); 3640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.Invoke(); 3660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // We're sliding the 3x3 filter across the 3x4 image, with no accesses outside 3670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // the input because we're using the 'VALID' padding mode, giving a 2x1 3680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // output. 3690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // The calculations behind the expected output are: 3700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 3710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 3720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // This means we should end up with this matrix: 3730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // | 312 | 357 | 3740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357})); 3750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 3760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleclass QuantizedConvolutionOpModel : public BaseConvolutionOpModel { 3780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle public: 3790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle using BaseConvolutionOpModel::BaseConvolutionOpModel; 3800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle void SetInput(std::initializer_list<float> data) { 3820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle QuantizeAndPopulate<uint8_t>(input_, data); 3830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 3840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle void SetFilter(std::initializer_list<float> data) { 3860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle QuantizeAndPopulate<uint8_t>(filter_, data); 3870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 3880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle void SetBias(std::initializer_list<float> data) { 3900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle QuantizeAndPopulate<int32_t>(bias_, data); 3910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 3920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); } 3940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle std::vector<float> GetDequantizedOutput() { 3950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_), 3960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle GetScale(output_), GetZeroPoint(output_)); 3970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 3980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}; 3990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 4000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// In this tests we set the input and output scales so that the results 4010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// match exactly the 'non-quantized' version. 402e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, SimpleTestQuantized) { 403e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling QuantizedConvolutionOpModel m(GetRegistration(), 404e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling {TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, 4050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, 4060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_UINT8, {}, -127, 128}); 4070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetInput({ 4080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // First batch 4090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1, 1, 1, 1, // row = 1 4100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2, 2, 2, 2, // row = 2 4110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Second batch 4120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1, 2, 3, 4, // row = 1 4130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1, 2, 3, 4, // row = 2 4140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }); 4150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetFilter({ 4160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1, 2, 3, 4, // first 2x2 filter 4170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle -1, 1, -1, 1, // second 2x2 filter 4180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle -1, -1, 1, 1, // third 2x2 filter 4190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }); 4200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetBias({1, 2, 3}); 4210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 4220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.Invoke(); 4230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 4240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle EXPECT_THAT(m.GetDequantizedOutput(), 4250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle ElementsAreArray(ArrayFloatNear( 4260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle { 4270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 18, 2, 5, // first batch, left 4280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 18, 2, 5, // first batch, right 4290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 17, 4, 3, // second batch, left 4300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 37, 4, 3, // second batch, right 4310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }, 4320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1e-5))); 4330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // For good measure, let's also verify the quantized values: 4340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle EXPECT_THAT(m.GetOutput(), ElementsAreArray({ 4350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 145, 129, 132, // 4360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 145, 129, 132, // 4370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 144, 131, 130, // 4380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 164, 131, 130, // 4390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle })); 4400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 4410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 442e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingTEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { 443e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling QuantizedConvolutionOpModel m(GetRegistration(), 444e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling {TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, 4450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64}, 4460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle {TensorType_UINT8, {}, -127, 128}, 4470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle /*stride_width=*/3, /*stride_height=*/1); 4480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetInput({ 4490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3, 2, 1, -1, -2, -3, // 4500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 4, 3, 2, -2, -3, -4, // 4510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 5, 4, 3, -3, -4, -5, // 4520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }); 4530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetFilter({ 4540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1, 2, // 4550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3, 4, // 4560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }); 4570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.SetBias({-1}); 4580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle m.Invoke(); 4590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ 4600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 30, -24, // 4610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 40, -34, // 4620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }))); 4630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle EXPECT_THAT(m.GetOutput(), ElementsAreArray({ 4640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 157, 103, // 4650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 167, 93, // 4660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle })); 4670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 468e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling 469e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng LingINSTANTIATE_TEST_CASE_P( 470e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling ConvolutionOpTest, ConvolutionOpTest, 471e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); 472e8e33b0050e7e1ff686312bcbdafa270c2e29462Yu-Cheng Ling 4730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace 4740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace tflite 4750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 4760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleint main(int argc, char** argv) { 47700791693e4d32bed92fcfadf09da321c9f548babA. Unique TensorFlower ::tflite::LogToStderr(); 4780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle ::testing::InitGoogleTest(&argc, argv); 4790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return RUN_ALL_TESTS(); 4800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 481