1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15// Unit test for TFLite SOFTMAX op.
16
17#include <iomanip>
18#include <memory>
19#include <vector>
20
21#include <gmock/gmock.h>
22#include <gtest/gtest.h>
23#include "tensorflow/contrib/lite/interpreter.h"
24#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
25#include "tensorflow/contrib/lite/kernels/register.h"
26#include "tensorflow/contrib/lite/kernels/test_util.h"
27#include "tensorflow/contrib/lite/model.h"
28
29namespace tflite {
30namespace {
31
32class SoftmaxOpModel : public SingleOpModel {
33 public:
34  SoftmaxOpModel(int batches, int size, float beta)
35      : batches_(batches), input_size_(size), beta_(beta) {
36    input_ = AddInput(TensorType_FLOAT32);
37    output_ = AddOutput(TensorType_FLOAT32);
38    SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
39                 CreateSoftmaxOptions(builder_, beta_).Union());
40    BuildInterpreter({{batches_, input_size_}});
41  }
42
43  void SetInput(std::initializer_list<float> data) {
44    PopulateTensor(input_, data);
45  }
46
47  void SetInput(int offset, float* begin, float* end) {
48    PopulateTensor(input_, offset, begin, end);
49  }
50
51  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
52
53 private:
54  int input_;
55  int output_;
56
57  int batches_;
58  int input_size_;
59  float beta_;
60};
61
62TEST(SoftmaxOpTest, SimpleTest) {
63  SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0);
64  m.SetInput({
65      1.0, 2.0, 3.0, 4.0, 5.0,       // b = 0
66      -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 0
67  });
68
69  m.Invoke();
70
71  EXPECT_THAT(
72      m.GetOutput(),
73      ElementsAreArray(ArrayFloatNear(
74          {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
75           0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231},
76          1e-6)));
77}
78
79TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) {
80  const int batch_size = 2;
81  const int input_size = 5;
82  const float beta = 1.0;
83  static float input_buffer[] = {
84      1.0,  2.0,  3.0,  4.0,  5.0,   // b = 0
85      -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 1
86  };
87
88  SoftmaxOpModel m(batch_size, input_size, beta);
89
90  m.SetInput(0, input_buffer, input_buffer + input_size * batch_size);
91
92  m.Invoke();
93
94  std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
95  static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
96                                       {1, 0, 0, input_size}};
97  tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
98                                 output_buffer.get(), input_dims);
99
100  std::vector<float> expected;
101  expected.insert(expected.end(), output_buffer.get(),
102                  output_buffer.get() + input_size * batch_size);
103
104  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6)));
105}
106
107TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
108  const int batch_size = 2;
109  const int input_size = 5;
110  const float beta = 0.5;
111  static float input_buffer[] = {
112      1.0,  2.0,  3.0,  4.0,  5.0,   // b = 0
113      -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 1
114  };
115
116  SoftmaxOpModel m(batch_size, input_size, beta);
117
118  m.SetInput(0, input_buffer, input_buffer + input_size * batch_size);
119
120  m.Invoke();
121
122  std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
123  static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
124                                       {1, 0, 0, input_size}};
125  tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
126                                 output_buffer.get(), input_dims);
127
128  std::vector<float> expected;
129  expected.insert(expected.end(), output_buffer.get(),
130                  output_buffer.get() + input_size * batch_size);
131
132  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6)));
133}
134
135}  // namespace
136}  // namespace tflite
137
138int main(int argc, char** argv) {
139  ::tflite::LogToStderr();
140  ::testing::InitGoogleTest(&argc, argv);
141  return RUN_ALL_TESTS();
142}
143