1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <gtest/gtest.h>
16#include "tensorflow/contrib/lite/interpreter.h"
17#include "tensorflow/contrib/lite/kernels/register.h"
18#include "tensorflow/contrib/lite/kernels/test_util.h"
19#include "tensorflow/contrib/lite/model.h"
20
21namespace tflite {
22namespace {
23
24using ::testing::ElementsAreArray;
25
26class PadOpModel : public SingleOpModel {
27 public:
28  void SetInput(std::initializer_list<float> data) {
29    PopulateTensor<float>(input_, data);
30  }
31
32  void SetPaddings(std::initializer_list<int> paddings) {
33    PopulateTensor<int>(paddings_, paddings);
34  }
35
36  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
37  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
38
39 protected:
40  int input_;
41  int output_;
42  int paddings_;
43};
44
// Test case where paddings is a const tensor.
//
// Example usage is as follows:
//    PadOpConstModel m(input_shape, paddings_shape, paddings_data);
//    m.SetInput(input_data);
//    m.Invoke();
51class PadOpConstModel : public PadOpModel {
52 public:
53  PadOpConstModel(std::initializer_list<int> input_shape,
54                  std::initializer_list<int> paddings_shape,
55                  std::initializer_list<int> paddings) {
56    input_ = AddInput(TensorType_FLOAT32);
57    paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape);
58    output_ = AddOutput(TensorType_FLOAT32);
59
60    SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
61                 CreatePadOptions(builder_).Union());
62    BuildInterpreter({input_shape});
63  }
64};
65
66// Test case where paddings is a non-const tensor.
67//
68// Example usage is as follows:
69//    PadOpDynamicModel m(input_shape, paddings_shape);
70//    m.SetInput(input_data);
71//    m.SetPaddings(paddings_data);
72//    m.Invoke();
73class PadOpDynamicModel : public PadOpModel {
74 public:
75  PadOpDynamicModel(std::initializer_list<int> input_shape,
76                    std::initializer_list<int> paddings_shape) {
77    input_ = AddInput(TensorType_FLOAT32);
78    paddings_ = AddInput(TensorType_INT32);
79    output_ = AddOutput(TensorType_FLOAT32);
80
81    SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
82                 CreatePadOptions(builder_).Union());
83    BuildInterpreter({input_shape, paddings_shape});
84  }
85};
86
87TEST(PadOpTest, TooManyDimensions) {
88  EXPECT_DEATH(
89      PadOpConstModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9, 2},
90                      {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}),
91      "dims != 4");
92}
93
94TEST(PadOpTest, UnequalDimensions) {
95  EXPECT_DEATH(PadOpConstModel({1, 1, 2, 1}, {3, 2}, {1, 1, 2, 2, 3, 3}),
96               "3 != 4");
97}
98
99TEST(PadOpTest, InvalidPadValue) {
100  EXPECT_DEATH(
101      PadOpConstModel({1, 1, 2, 1}, {4, 2}, {0, 0, 1, -1, 2, -1, 0, 0}),
102      "Pad value has to be greater than equal to 0.");
103}
104
105TEST(PadOpTest, SimpleConstTest) {
106  // Padding is represented as four 2-D lists representing above padding and
107  // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
108  PadOpConstModel m({1, 2, 2, 1}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0});
109  m.SetInput({1, 2, 3, 4});
110  m.Invoke();
111  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4,
112                                               0, 0, 0, 0, 0}));
113  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
114}
115
116TEST(PadOpTest, SimpleDynamicTest) {
117  PadOpDynamicModel m({1, 2, 2, 1}, {4, 2});
118  m.SetInput({1, 2, 3, 4});
119  m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
120  m.Invoke();
121  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4,
122                                               0, 0, 0, 0, 0}));
123  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
124}
125
126TEST(PadOpTest, AdvancedConstTest) {
127  PadOpConstModel m({1, 2, 3, 1}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0});
128  m.SetInput({1, 2, 3, 4, 5, 6});
129  m.Invoke();
130  EXPECT_THAT(m.GetOutput(),
131              ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
132                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
133  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
134}
135
136TEST(PadOpTest, AdvancedDynamicTest) {
137  PadOpDynamicModel m({1, 2, 3, 1}, {4, 2});
138  m.SetInput({1, 2, 3, 4, 5, 6});
139  m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
140  m.Invoke();
141  EXPECT_THAT(m.GetOutput(),
142              ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
143                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
144  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
145}
146
147}  // namespace
148}  // namespace tflite
149
150int main(int argc, char** argv) {
151  ::tflite::LogToStderr();
152  ::testing::InitGoogleTest(&argc, argv);
153  return RUN_ALL_TESTS();
154}
155