1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
24
25namespace tflite {
26
27namespace ops {
28namespace custom {
29TfLiteRegistration* Register_PREDICT();
30
31namespace {
32
33using ::testing::ElementsAreArray;
34
35class PredictOpModel : public SingleOpModel {
36 public:
37  PredictOpModel(std::initializer_list<int> input_signature_shape,
38                 std::initializer_list<int> key_shape,
39                 std::initializer_list<int> labelweight_shape, int num_output,
40                 float threshold) {
41    input_signature_ = AddInput(TensorType_INT32);
42    model_key_ = AddInput(TensorType_INT32);
43    model_label_ = AddInput(TensorType_INT32);
44    model_weight_ = AddInput(TensorType_FLOAT32);
45    output_label_ = AddOutput(TensorType_INT32);
46    output_weight_ = AddOutput(TensorType_FLOAT32);
47
48    std::vector<uint8_t> predict_option;
49    writeInt32(num_output, &predict_option);
50    writeFloat32(threshold, &predict_option);
51    SetCustomOp("Predict", predict_option, Register_PREDICT);
52    BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape,
53                       labelweight_shape}});
54  }
55
56  void SetInputSignature(std::initializer_list<int> data) {
57    PopulateTensor<int>(input_signature_, data);
58  }
59
60  void SetModelKey(std::initializer_list<int> data) {
61    PopulateTensor<int>(model_key_, data);
62  }
63
64  void SetModelLabel(std::initializer_list<int> data) {
65    PopulateTensor<int>(model_label_, data);
66  }
67
68  void SetModelWeight(std::initializer_list<float> data) {
69    PopulateTensor<float>(model_weight_, data);
70  }
71
72  std::vector<int> GetLabel() { return ExtractVector<int>(output_label_); }
73  std::vector<float> GetWeight() {
74    return ExtractVector<float>(output_weight_);
75  }
76
77  void writeFloat32(float value, std::vector<uint8_t>* data) {
78    union {
79      float v;
80      uint8_t r[4];
81    } float_to_raw;
82    float_to_raw.v = value;
83    for (unsigned char i : float_to_raw.r) {
84      data->push_back(i);
85    }
86  }
87
88  void writeInt32(int32_t value, std::vector<uint8_t>* data) {
89    union {
90      int32_t v;
91      uint8_t r[4];
92    } int32_to_raw;
93    int32_to_raw.v = value;
94    for (unsigned char i : int32_to_raw.r) {
95      data->push_back(i);
96    }
97  }
98
99 private:
100  int input_signature_;
101  int model_key_;
102  int model_label_;
103  int model_weight_;
104  int output_label_;
105  int output_weight_;
106};
107
108TEST(PredictOpTest, AllLabelsAreValid) {
109  PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
110  m.SetInputSignature({1, 3, 7, 9});
111  m.SetModelKey({1, 2, 4, 6, 7});
112  m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
113  m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
114  m.Invoke();
115  EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11}));
116  EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05})));
117}
118
119TEST(PredictOpTest, MoreLabelsThanRequired) {
120  PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001);
121  m.SetInputSignature({1, 3, 7, 9});
122  m.SetModelKey({1, 2, 4, 6, 7});
123  m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
124  m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
125  m.Invoke();
126  EXPECT_THAT(m.GetLabel(), ElementsAreArray({12}));
127  EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1})));
128}
129
130TEST(PredictOpTest, OneLabelDoesNotPassThreshold) {
131  PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07);
132  m.SetInputSignature({1, 3, 7, 9});
133  m.SetModelKey({1, 2, 4, 6, 7});
134  m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
135  m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
136  m.Invoke();
137  EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1}));
138  EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0})));
139}
140
141TEST(PredictOpTest, NoneLabelPassThreshold) {
142  PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6);
143  m.SetInputSignature({1, 3, 7, 9});
144  m.SetModelKey({1, 2, 4, 6, 7});
145  m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
146  m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
147  m.Invoke();
148  EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
149  EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
150}
151
152TEST(PredictOpTest, OnlyOneLabelGenerated) {
153  PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
154  m.SetInputSignature({1, 3, 7, 9});
155  m.SetModelKey({1, 2, 4, 6, 7});
156  m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0});
157  m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0});
158  m.Invoke();
159  EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1}));
160  EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0})));
161}
162
163TEST(PredictOpTest, NoLabelGenerated) {
164  PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
165  m.SetInputSignature({5, 3, 7, 9});
166  m.SetModelKey({1, 2, 4, 6, 7});
167  m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0});
168  m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0});
169  m.Invoke();
170  EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
171  EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
172}
173
174}  // namespace
175}  // namespace custom
176}  // namespace ops
177}  // namespace tflite
178
179int main(int argc, char** argv) {
180  // On Linux, add: tflite::LogToStderr();
181  ::testing::InitGoogleTest(&argc, argv);
182  return RUN_ALL_TESTS();
183}
184